1 Introduction
Continual learning aims to incrementally learn a sequence of tasks (chen2018lifelong). Each task consists of a set of classes to be learned together. A major challenge of CL is catastrophic forgetting (CF). Although a large number of CL techniques have been proposed, they are mainly empirical. Limited theoretical research has done on how to solve CL. This paper performs such a theoretical study about the necessary and sufficient conditions for effective CL. There are two main CL settings that have been extensively studied: class incremental learning (CIL) and task incremental learning (TIL) (van2019three)
. In CIL, the learning process builds a single classifier for all tasks/classes learned so far. In testing, a test instance from any class may be presented for the model to classify. No prior task information (e.g., task-id) of the test instance is provided. Formally, CIL is defined as follows.
Class incremental learning (CIL). CIL learns a sequence of tasks, . Each task has a training dataset , where is the number of data samples in task , and is an input sample and (the set of all classes of task ) is its class label. All ’s are disjoint () and . The goal of CIL is to construct a single predictive function or classifier that can identify the class label of each given test instance .
In the TIL setup, each task is a separate classification problem. For example, one task could be to classify different breeds of dogs and another task could be to classify different types of animals (the tasks may not be disjoint). One model is built for each task in a shared network. In testing, the task-id of each test instance is provided and the system uses only the specific model for the task (dog or animal classification) to classify the test instance. Formally, TIL is defined as follows.
Task incremental learning (TIL). TIL learns a sequence of tasks, . Each task has a training dataset , where is the number of data samples in task , and is an input sample and is its class label. The goal of TIL is to construct a predictor to identify the class label for (the given test instance from task ).
Several techniques are available to effectively overcome CF for TIL (with almost no CF) Serra2018overcoming; supsup2020. However, CIL remains to be highly challenging due to the additional problem of Inter-task Class Separation (ICS) (establishing decision boundaries between the classes of the new task and the classes of the previous tasks) because the previous task data are not accessible. Before discussing the proposed work, we recall the closed-world
assumption made by traditional machine learning, i.e.,
the classes seen in testing must have been seen in training chen2018lifelong; liu2021self. However, in many applications, there are unknowns in testing, which is called the open world setting chen2018lifelong; liu2021self. In open world learning, the training (or known) classes are called in-distribution (IND) classes. A classifier built for the open world can (1) classify test instances of training/IND classes to their respective classes, which is called IND prediction, and (2) detect test instances that do not belong to any of the IND/known classes but some unknown or out-of-distribution (OOD) classes, which is called OOD detection. Many OOD detection algorithms can perform both IND prediction and OOD detection tack2020csi; liang2018enhancing; esmaeilpour2022zero; wang2022omg. The commonality of OOD detection and CL is that they both need to consider future unknowns.This paper conducts a theoretical study of CIL, which is applicable to any CIL classification model. Instead of focusing on the traditional PAC generalization bound pentina2014pac; karakida2022learning, we focus on how to solve the CIL problem. We first decompose the CIL problem into two sub-problems in a probabilistic framework: Within-task Prediction (WP) and Task-id Prediction (TP). WP means that the prediction for a test instance is only done within the classes of the task to which the test instance belongs, which is basically the TIL problem. TP predicts the task-id. TP is needed because in CIL, task-id is not provided in testing. This paper then proves based on the popular cross-entropy loss that (1) the CIL performance is bounded by WP and TP performances, and (2) TP and task OOD detection performance bound each other (which connects CL and OOD detection). The key result is that regardless of whether WP and TP or OOD detection are defined explicitly or implicitly by a CIL algorithm, good WP and good TP or OOD detection are necessary and sufficient conditions for good CIL performances. This result is applicable to both batch/offline and online CIL and to CIL problems with blurry task boundaries. The intuition is also quite simple because if a CIL model is perfect at detecting OOD samples for each task (which solves the ICS problem), then CIL is reduced to WP.
This theoretical result provides a principled guidance for solving the CIL problem, i.e., to help design better CIL algorithms that can achieve strong WP and TP performances. Since WP is basically IND prediction for each task and most OOD techniques perform both IND prediction and OOD detection, to achieve good CIL accuracy, a strong OOD performance for each task is necessary.
Based on the theoretical guidance, several new CIL methods are designed, including techniques based on the integration of a TIL method and an OOD detection method for CIL, which outperform strong baselines in both the CIL and TIL settings by a large margin. This combination is particularly attractive because TIL has achieved no forgetting, and we only need a strong OOD technique that can perform both IND prediction and OOD detection to learn each task to achieve strong CIL results.
2 Related Work
Although numerous CL techniques have been proposed, little study has been done to provide a theoretical guidance on how to solve the problem. Existing approaches mainly belong to several categories. Using regularization kirkpatrick2017overcoming; Li2016LwF to minimize changes to model parameters learned from previous tasks is a popular approach Jung2016less; Camoriano2017incremental; zenke2017continual; ritter2018online; schwarz2018progress; xu2018reinforced; castro2018end; hu2019overcoming; Dhar2019CVPR; lee2019overcoming; ahn2019neurIPS; Liu2020; Zhu_2021_CVPR_pass. Memorizing some old examples and using them to jointly train the new task is another popular approach (called replay) Rusu2016; Lopez2017gradient; Rebuffi2017; Chaudhry2019ICLR; hou2019learning; wu2019large; rolnick2019neurIPS; NEURIPS2020_b704ea2c_derpp; rajasegaran2020adaptive; Liu2020AANets; Cha_2021_ICCV_co2l; yan2021dynamically; wang2022memory; guo2022online. Some systems learn to generate pseudo training data of old tasks and use them to jointly train the new task, called pseudo-replay Gepperth2016bio; Kamra2017deep; Shin2017continual; wu2018memory; Seff2017continual; Kemker2018fearnet; hu2019overcoming; Rostami2019ijcai; ostapenko2019learning. Orthogonal projection learns each task in an orthogonal space to other tasks zeng2019continuous; guo2022adaptive; chaudhry2020continual. Our theoretical study is applicable to any continually trained classification models.
Parameter isolation is yet another popular approach, which makes different subsets (which may overlap) of the network parameters dedicated to different tasks using masks Serra2018overcoming; ke2020continual; Mallya2017packnet; supsup2020; NEURIPS2019_3b220b43compact. This approach is particularly suited for TIL. Several methods have almost completely overcome forgetting. HAT Serra2018overcoming and CAT ke2020continual protect previous tasks by masking the important parameters to those tasks. PackNet Mallya2017packnet, CPG NEURIPS2019_3b220b43compact and SupSup supsup2020 find an isolated sub-network for each task. HyperNet von2019continual initializes task-specific parameters conditioned on task-id. ADP Yoon2020Scalable decomposes parameters into shared and adaptive parts to construct an order robust TIL system. CCLL singh2020calibrating uses task-adaptive calibration in convolution layers. Our methods designed based on the proposed theory make use of two parameter isolation-based TIL methods and two OOD detection methods. A strong OOD detection method CSI in tack2020csi helps produce very strong CIL results. CSI is based on data augmentation he2020momentum and contrastive learning chen2020simple. Excellent surveys of OOD detection include bulusu2020anomalous; geng2020recent.
Some methods have used a TIL method for CIL with an additional task-id prediction technique. iTAML rajasegaran2020itaml requires each test batch to be from a single task. This is not practical as test samples usually come one by one. CCG abati2020conditional builds a separate network to predict the task-id. Expert Gate Aljundi2016expert
constructs a separate autoencoder for each task. HyperNet
von2019continual and PR-Ent henning2021posterior use entropy to predict the task id. Since none of these papers is a theoretical study, they did not know that strong OOD detection is the key. Our methods based on OOD detection perform dramatically better.Several theoretical studies have been made on lifelong/continual learning. However, they focus on traditional generalization bound. pentina2014pac proposes a PAC-Bayesian framework to provide a learning bound on expected error in future tasks by the average loss on the observed tasks. The work in lee2021continual studies the generalization error by task similarity and karakida2022learning studies the dependence of generalization error on sample size or number of tasks including forward and backward transfer. bennani2020generalisation shows that orthogonal gradient descent gives a tighter generalization bound than SGD. Our work is very different as we focus on how to solve the CIL problem, which is orthogonal to the existing theoretical analysis.
3 CIL by Within-Task Prediction and Task-ID Prediction
This section presents our theory. It first shows that the CIL performance improves if the within-task prediction (WP) performance and/or the task-id prediction (TP) performance improve, and then shows that TP and OOD detection bound each other, which indicates that CIL performance is controlled by WP and OOD detection. This connects CL and OOD detection. Finally, we study the necessary conditions for a good CIL model, which includes a good WP, and a good TP (or OOD detection).
3.1 CIL Problem Decomposition
This sub-section first presents the assumptions made by CIL based on its definition and then proposes a decomposition of the CIL problem into two sub-problems. A CL system learns a sequence of tasks , where is the domain of task and are classes of task as , where indicates the th class in task . Let to be the domain of th class of task , where . For accuracy, we will use instead of in probabilistic analysis. Based on the definition of class incremental learning (CIL) (Sec. 1), the following assumptions are implied,
Assumption 1.
The domains of classes of the same task are disjoint, i.e., .
Assumption 2.
The domains of tasks are disjoint, i.e., .
For any ground event , the goal of a CIL problem is to learn
. This can be decomposed into two probabilities,
within-task IND prediction (WP) probability and task-id prediction (TP) probability. WP probability is and TP probability is . We can rewrite the CIL problem using WP and TP based on the two assumptions,(1) | ||||
(2) |
where means a particular task and a particular class in the task.
Some remarks are in order about Eq. 2 and our subsequent analysis to set the stage.
Remark 1.
Eq. 2 shows that if we can improve either the WP or TP performance, or both, we can improve the CIL performance.
Remark 2.
It is important to note that our theory is not concerned with the learning algorithm or the training process, but we will propose some concrete learning algorithms based on the theoretical result in the experiment section.
Remark 3.
Note that the CIL definition and the subsequent analysis are applicable to tasks with any number of classes (including only one class per task) and to online CIL where the training data for each task or class comes gradually in a data stream and may also cross task boundaries (blurry tasks bang2021rainbow) because our analysis is based on an already-built CIL model after training. Regarding blurry task boundaries, suppose dataset 1 has classes {dog, cat, tiger} and dataset 2 has classes {dog, computer, car}. We can define task 1 as {dog, cat, tiger} and task 2 as {computer, car}. The shared class dog in dataset 2 is just additional training data of dog appeared after task 1.
Remark 4.
Furthermore, CIL = WP * TP in Eq. 2 means that when we have WP and TP (defined either explicitly or implicitly by implementation), we can find a corresponding CIL model defined by WP * TP. Similarly, when we have a CIL model, we can find the corresponding underlying WP and TP defined by their probabilistic definitions.
In the following sub-sections, we develop this further concretely to derive the sufficient and necessary conditions for solving the CIL problem in the context of cross-entropy loss as it is used in almost all supervised CL systems.
3.2 CIL Improves as WP and/or TP Improve
As stated in Remark 2 above, the study here is based on a trained CIL model
and not concerned with the algorithm used in training the model. We use cross-entropy as the performance measure of a trained model as it is the most popular loss function used in supervised CL. For experimental evaluation, we use
accuracyfollowing CL papers. Denote the cross-entropy of two probability distributions
and as(3) |
For any , let to be the CIL ground truth label of , where if otherwise , . Let be the WP ground truth label of , where if otherwise , . Let be the TP ground truth label of , where if otherwise , . Denote
(4) | ||||
(5) | ||||
(6) |
where , and are the cross-entropy values of WP, CIL and TP, respectively. We now present our first theorem. The theorem connects CIL to WP and TP, and suggests that by having a good WP or TP, the CIL performance improves as the upper bound for the CIL loss decreases.
Theorem 1.
If and , we have
The detailed proof is given in Appendix A.1. This theorem holds regardless of whether WP and TP are trained together or separately. When they are trained separately, if WP is fixed and we let , , which means if TP is better, CIL is better. Similarly, if TP is fixed, we have . When they are trained concurrently, there exists a functional relationship between and depending on implementation. But no matter what it is, when decreases, CIL gets better.
Theorem 1 holds for any that satisfies or . To measure the overall performance under expectation, we present the following corollary.
Corollary 1.
Let represents the uniform distribution on
The proof is given in Appendix A.2. The corollary is a direct extension of Theorem 1 in expectation. The implication is that given TP performance, CIL is positively related to WP. The better the WP is, the better the CIL is as the upper bound of the CIL loss decreases. Similarly, given WP performance, a better TP performance results in a better CIL performance. Due to the positive relation, we can improve CIL by improving either WP or TP using their respective methods developed in each area.
3.3 Task Prediction (TP) to OOD Detection
Building on Eq. 2, we have studied the relationship of CIL, WP and TP in Theorem 1. We now connect TP and OOD detection. They are shown to be dominated by each other to a constant factor.
We again use cross-entropy to measure the performance of TP and OOD detection of a trained network as in Sec. 3.2 To build the connection between and OOD detection of each task, we first define the notations of OOD detection. We use to represent the probability distribution predicted by the th task’s OOD detector. Notice that the task prediction (TP) probability distribution is a categorical distribution over tasks, while the OOD detection probability distribution
is a Bernoulli distribution. For any
, define(7) |
In CIL, the OOD detection probability for a task can be defined using the output values corresponding to the classes of the task. Some examples of the function is a sigmoid of maximum logit value or a maximum softmax probability after re-scaling to 0 to 1. It is also possible to define the OOD detector directly as a function of tasks instead of a function of the output values of all classes of tasks, i.e. Mahalanobis distance. The following theorem shows that TP and OOD detection bound each other.
Theorem 2.
i) If , let , then . ii) If , let , then , where is an indicator function.
See Appendix A.3 for the proof. As we use cross-entropy, the lower the bound, the better the performance is. The first statement (i) says that the OOD detection performance improves if the TP performance gets better (i.e., lower ). Similarly, the second statement (ii) says that the TP performance improves if the OOD detection performance on each task improves (i.e., lower ). Besides, since converges to as ’s converge to in order of , we further know that and are equivalent in quantity up to a constant factor.
In Theorem 1, we studied how CIL is related to WP and TP. In Theorem 2, we showed that TP and OOD bound each other. Now we explicitly give the upper bound of CIL in relation to WP and OOD detection of each task. The detailed proof can be found in Appendix A.4.
Theorem 3.
If and , we have
where is an indicator function.
3.4 Necessary Conditions for Improving CIL
In Theorem 1, we showed that good performances of WP and TP are sufficient to guarantee a good performance of CIL. In Theorem 3, we showed that good performances of WP and OOD are sufficient to guarantee a good performance of CIL. For completeness, we study the necessary conditions of a well-performed CIL in this sub-section.
Theorem 4.
If , then there exist i) a WP, s.t. , ii) a TP, s.t. , and iii) an OOD detector for each task, s.t. .
The detailed proof is given in Appendix A.5. This theorem tells that if a good CIL model is trained, then a good WP, a good TP and a good OOD detector for each task are always implied. More importantly, by transforming Theorem 4 into its contraposition, we have the following statements: If for any WP, , then . If for any TP, , then . If for any OOD detector, , then . Regardless of whether WP and TP (or OOD detection) are defined explicitly or implicitly by a CIL algorithm, the existence of a good WP and the existence of a good TP or OOD detection are necessary conditions for a good CIL performance.
Remark 5.
It is important to note again that our study in this section is based on a CIL model that has already been built. In other words, our study tells the CIL designers what should be achieved in the final model. Clearly, one would also like to know how to design a strong CIL model based on the theoretical results, which also considers catastrophic forgetting (CF). One effective method is to make use of a strong existing TIL algorithm, which can already achieve no or little forgetting (CF), and combine it with a strong OOD detection algorithm (as mentioned earlier, most OOD detection methods can also perform WP). Thus, any improved method from the OOD detection community can be applied to CIL to produce improved CIL systems (see Sections 4.3 and 4.4).
Recall in Section 2, we reviewed prior works that have tried to use a TIL method for CIL with a task-id prediction method von2019continual; Aljundi2016expert; rajasegaran2020itaml; abati2020conditional; henning2021posterior. However, since they did not know that the key to the success of this approach is a strong OOD detection algorithm, they are quite weak (see Section 4).
4 New CIL Techniques and Experiments
Based on Theorem 3, we have designed several new CIL methods, each of which integrates an existing CL algorithm and an OOD detection algorithm. The OOD detection algorithm that we use can perform both within-task IND prediction (WP) and OOD detection. Our experiments have two goals: (1) to show that a good OOD detection method can help improve the accuracy of an existing CIL algorithm, and (2) to fully compare two of these methods (see some others in Sec. 4.5) with strong baselines to show that they outperform the existing strong baselines considerably.
4.1 Datasets, CL Baselines and OOD Detection Methods
Datasets and CIL Tasks. Four popular benchmark image classification datasets are used, from which six CIL problems are created following recent papers Liu2020; NEURIPS2020_b704ea2c_derpp; Zhu_2021_CVPR_pass. (1) MNIST consists of handwritten images of 10 digits with 60,000/10,000 training/testing samples. We create a CIL problem (M-5T) of 5 tasks with 2 consecutive classes/digits as a task. (2) CIFAR-10 consists of 32x32 color images of 10 classes with 50,000/10,000 training/testing samples. We create a CIL problem (C10-5T) of 5 tasks with 2 consecutive classes as a task. (3) CIFAR-100 consists of 60,000 32x32 color images with 50,000/10,000 training/testing samples. We create two CIL problems by splitting 100 classes into 10 tasks (C100-10T) and 20 tasks (C100-20T), where each task has 10 and 5 classes, respectively. (4) Tiny-ImageNet
Baseline CL Methods. We include different families of CL methods: regularization, replay, orthogonal projection, and parameter isolation. MUC Liu2020 and PASS Zhu_2021_CVPR_pass are regularization-based methods. For replay methods, we use LwF Li2016LwF, iCaRL Rebuffi2017, Mnemonics Liu_2020_CVPR, BiC wu2019large, DER++ NEURIPS2020_b704ea2c_derpp, and CoL Cha_2021_ICCV_co2l. For orthogonal projection, we use OWM zeng2019continuous. Finally, for parameter isolation, we use CCG abati2020conditional, HyperNet von2019continual, HAT Serra2018overcoming, SupSup supsup2020 (Sup), and PR henning2021posterior.222iTAML rajasegaran2020itaml is not included as it requires a batch of test data from the same task to predict the task-id. When each batch has only one test sample, which is our setting, it is very weak. For example, its CIL accuracy is only 33.5% on C100-10T. Expert Gate (EG) Aljundi2016expert is also very weak. Its CIL accuracy is only 43.2 on M-5T. They are much weaker than many baselines. DER yan2021dynamically is not included as it expands the network after each task, which is somewhat unfair to other systems as all others do not expand the network. DER can generate a large number of parameters after the last task, e.g., 117.6 millions (M) for C100-20T while our proposed methods require 44.6M (HAT+CSI) and 11.6M (Sup+CSI) (refer to Appendix H). The average accuracy of DER over the 6 CL experiments is 61.4 while our methods achieve 67.9 (HAT+CSI+c) and 64.9 (Sup+CSI+c) (refer to Tab. 3). We use the official codes for the baselines except for , CCG, and PR. For these three systems, we copy the results from their papers as the code for CCG is not released and we are unable to run and PR on our machines.
OOD Detection Methods. Two OOD detection methods are used. We combine them with the above existing CL algorithms. Both these methods can also perform within-task IND prediction (WP).
(1). ODIN: Researchers have proposed several methods to improve the OOD detection performance of a trained network by post-processing liang2018enhancing; liu2020energy; lee2018simple_md. ODIN liang2018enhancing is a representative method. It adds perturbation to input and applies a temperature scaling to the softmax output of a trained network.
(2). CSI: It is a recently proposed OOD detection technique tack2020csi that is highly effective. It is based on data and class label augmentation and supervised contrastive learning khosla2020supervised. Its rotation data augmentations create distributional shifted samples to act as negative data for the original samples for contrastive learning. The details of CSI is given in Appendix D.
4.2 Training Details and Evaluation Metrics
Training Details. For the backbone structure, we follow supsup2020; Zhu_2021_CVPR_pass; NEURIPS2020_b704ea2c_derpp. AlexNet-like architecture NIPS2012_c399862d_alexnet is used for MNIST and ResNet-18 he2016deep
is used for CIFAR-10. For CIFAR-100 and Tiny-ImageNet, ResNet-18 is also used as CIFAR-10, but the number of channels are doubled to fit more classes. All the methods use the same backbone architecture except for OWM and HyperNet, for which we use their original architectures. OWM uses AlexNet. It is not obvious how to apply the technique to the ResNet structure. HyperNet uses a fully-connected network and ResNet-32 for MNIST and other datasets, respectively. We are unable to change the structure due to model initialization arguments unexplained in the original paper. For the replay methods, we use memory buffer 200 for MNIST and CIFAR-10 and 2000 for CIFAR-100 and Tiny-ImageNet as in
Rebuffi2017; NEURIPS2020_b704ea2c_derpp. We use the hyper-parameters suggested by the authors. If we could not reproduce any result, we use 10% of the training data as a validation set to grid-search for good hyper-parameters. For our proposed methods, we report the hyper-parameters in Appendix G. All the results are averages over 5 runs with random seeds.(1). Average classification accuracy over all classes after learning the last task. The final class prediction depends prediction methods (see below). We also report forgetting rate in Appendix J.
(2). Average AUC (Area Under the ROC Curve) over all task models for the evaluation of OOD detection. AUC is the main measure used in OOD detection papers. Using this measure, we show that a better OOD detection method will result in a better CIL performance. Let be the AUC score of task . It is computed by using only the model (or classes) of task to score the test data of task as the in-distribution (IND) data and the test data from other tasks as the out-of-distribution (OOD) data. The average AUC score is: , where is the number of tasks.
It is not straightforward to change existing CL algorithms to include a new OOD detection method that needs training, e.g., CSI, except for TIL (task incremental learning) methods, e.g., HAT and Sup. For HAT and Sup, we can simply switch their methods for learning each task with CSI (see Sec.4.4).
Prediction Methods. The theoretical result in Sec. 3 states that we use Eq. 2 to perform the final prediction. The first probability (WP) in Eq. 2 is easy to get as we can simply use the softmax values of the classes in each task. However, the second probability (TP) in Eq. 2 is tricky as each task is learned without the data of other tasks. There can be many options. We take the following approaches for prediction (which are a special case of Eq. 2, see below):
(1). For those approaches that use a single classification head to include all classes learned so far, we predict as follows (which is also the approach taken by the existing papers.)
(8) |
where is the logit output of the network.
(2). For multi-head methods (e.g., HAT, HyperNet, and Sup), which use one head for each task, we use the concatenated output as
(9) |
where indicate concatenation and is the output of task .333The Sup paper proposed an one-shot task-id prediction assuming that the test instances come in a batch and all belong to the same task like iTAML. We assume a single test instance per batch. Its task-id prediction results in accuracy of 50.2 on C10-5T, which is much lower than 62.6 by using Eq. 9. The task-id prediction of HyperNet also works poorly. The accuracy by its id prediction is 49.34 on C10-5T while it is 53.4 using Eq. 9. PR uses entropy to find task-id. Among many variations of PR, we use the variations that perform the best for each dataset with exemplar-free and single sample per batch at testing (i.e., no PR-BW).
These methods (in fact, they are the same method used in two different settings) is a special case of Eq. 2 if we define as , where is the sigmoid. Hence, the theoretical results in Sec. 3 are still applicable. We present a detailed explanation about this prediction method and some other options in Appendix C. These two approaches work quite well.
4.3 Better OOD Detection Produces Better CIL Performance
The key theoretical result in Sec. 3 is that better OOD detection will produce better CIL performance. Recall our considered methods ODIN and CSI can perform both WP and OOD detection.
Applying ODIN. We first train the baseline models using their original algorithms, and then apply temperature scaling and input noise of ODIN at testing for each task (no training data needed). More precisely, the output of class in task changes by temperature scaling factor of task as
(10) |
and the input changes by the noise factor as
(11) |
where is the class with the maximum output value in task . This is a positive adversarial example inspired by goodfellow2015explaining. The values and are hyper-parameters and we use the same values for all tasks except for PASS, for which we had to use a validation set to tune (see Appendix B).
Method | OOD | AUC | CIL |
---|---|---|---|
OWM | Original | 71.31 | 28.91 |
ODIN | 70.06 | 28.88 | |
MUC | Original | 72.69 | 30.42 |
ODIN | 72.53 | 29.79 | |
PASS | Original | 69.89 | 33.00 |
ODIN | 69.60 | 31.00 | |
LwF | Original | 88.30 | 45.26 |
ODIN | 87.11 | 51.82 | |
BiC | Original | 87.89 | 52.92 |
ODIN | 86.73 | 48.65 | |
DER++ | Original | 85.99 | 53.71 |
ODIN | 88.21 | 55.29 | |
HAT | Original | 77.72 | 41.06 |
ODIN | 77.80 | 41.21 | |
HyperNet | Original | 71.82 | 30.23 |
ODIN | 72.32 | 30.83 | |
Sup | Original | 79.16 | 44.58 |
ODIN | 80.58 | 46.74 |
Tab. 1 gives the results for C100-10T. The CIL results clearly show that the CIL performance increases if the AUC increases with ODIN. For instance, the CIL of DER++ and Sup improves from 53.71 to 55.29 and 44.58 to 46.74, respectively, as the AUC increases from 85.99 to 88.21 and 79.16 to 80.58. It shows that when this method is incorporated into each task model in existing trained CIL network, the CIL performance of the original method improves. We note that ODIN does not always improve the average AUC. For those experienced a decrease in AUC, the CIL performance also decreases except LwF. The inconsistency of LwF is due to its severe classification bias towards later tasks as discussed in BiC wu2019large. The temperature scaling in ODIN has a similar effect as the bias correction in BiC, and the CIL of LwF becomes close to that of BiC after the correction. Regardless of whether ODIN improves AUC or not, the positive correlation between AUC and CIL (except LwF) verifies the efficacy of Theorem 3, indicating better OOD detection results in better CIL performances.
Applying CSI. We now apply the OOD detection method CSI. Due to its sophisticated data augmentation, supervised constrative learning and results ensemble, it is hard to apply CSI to other baselines without fundamentally change them except for HAT and Sup (SupSup) as these methods are parameter isolation-based TIL methods. We can simply replace their model for training each task with CSI wholesale (the full detail is given in Appendix D). As mentioned earlier, both HAT and SupSup as TIL methods have almost no forgetting.
Tab. 2 reports the results of using CSI and ODIN. ODIN is a weaker OOD method than CSI. Both HAT and Sup improve greatly as the systems are equipped with a better OOD detection method CSI. These experiment results empirically demonstrate the efficacy of Theorem 3, i.e., the CIL performance can be improved if a better OOD detection method is used.
CL | OOD | C10-5T | C100-10T | C100-20T | T-5T | T-10T | |||||
---|---|---|---|---|---|---|---|---|---|---|---|
AUC | CIL | AUC | CIL | AUC | CIL | AUC | CIL | AUC | CIL | ||
HAT | ODIN | 82.5 | 62.6 | 77.8 | 41.2 | 75.4 | 25.8 | 72.3 | 38.6 | 71.8 | 30.0 |
CSI | 91.2 | 87.8 | 84.5 | 63.3 | 86.5 | 54.6 | 76.5 | 45.7 | 78.5 | 47.1 | |
Sup | ODIN | 82.4 | 62.6 | 80.6 | 46.7 | 81.6 | 36.4 | 74.0 | 41.1 | 74.6 | 36.5 |
CSI | 91.6 | 86.0 | 86.8 | 65.1 | 88.3 | 60.2 | 77.1 | 48.9 | 79.4 | 45.7 |
Method | M-5T | C10-5T | C100-10T | C100-20T | T-5T | T-10T |
---|---|---|---|---|---|---|
OWM | 95.8
0.13 |
51.8
0.05 |
28.9
0.60 |
24.1
0.26 |
10.0
0.55 |
8.6
0.42 |
MUC | 74.9
0.46 |
52.9
1.03 |
30.4
1.18 |
14.2
0.30 |
33.6
0.19 |
17.4
0.17 |
PASS | 76.6
1.67 |
47.3
0.98 |
33.0
0.58 |
25.0
0.69 |
28.4
0.51 |
19.1
0.46 |
LwF | 85.5
3.11 |
54.7
1.18 |
45.3
0.75 |
44.3
0.46 |
32.2
0.50 |
24.3
0.26 |
iCaRL | 96.0
0.43 |
63.4
1.11 |
51.4
0.99 |
47.8
0.48 |
37.0
0.41 |
28.3
0.18 |
Mnemonics | 96.3
0.36 |
64.1
1.47 |
51.0
0.34 |
47.6
0.74 |
37.1
0.46 |
28.5
0.72 |
BiC | 94.1
0.65 |
61.4
1.74 |
52.9
0.64 |
48.9
0.54 |
41.7
0.74 |
33.8
0.40 |
DER++ | 95.3
0.69 |
66.0
1.20 |
53.7
1.30 |
46.6
1.44 |
35.8
0.77 |
30.5
0.47 |
CoL | 65.6 | |||||
CCG | 97.3 | 70.1 | ||||
HAT | 81.9
3.74 |
62.7
1.45 |
41.1
0.93 |
25.6
0.51 |
38.5
1.85 |
29.8
0.65 |
HyperNet | 56.6
4.85 |
53.4
2.19 |
30.2
1.54 |
18.7
1.10 |
7.9
0.69 |
5.3
0.50 |
Sup | 70.1
1.51 |
62.4
1.45 |
44.6
0.44 |
34.7
0.30 |
41.8
1.50 |
36.5
0.36 |
PR-Ent | 74.1 | 61.9 | 45.2 | |||
HAT+CSI | 94.4
0.26 |
87.8
0.71 |
63.3
1.00 |
54.6
0.92 |
45.7
0.26 |
47.1
0.18 |
Sup+CSI | 80.7
2.71 |
86.0
0.41 |
65.1
0.39 |
60.2
0.51 |
48.9
0.25 |
45.7
0.76 |
HAT+CSI+c | 96.9
0.30 |
88.0
0.48 |
65.2
0.71 |
58.0
0.45 |
51.7
0.37 |
47.6
0.32 |
Sup+CSI+c | 81.0
2.30 |
87.3
0.37 |
65.2
0.37 |
60.5
0.64 |
49.2
0.28 |
46.2
0.53 |
4.4 Full Comparison of HAT+CSI and Sup+CSI with Baselines
We now make a full comparison of the two strong systems (HAT+CSI and Sup+CSI) designed based on the theoretical results. These combinations are particularly attractive because both HAT and Sup are TIL systems and have little or no CF. Then a strong OOD method (that can also perform WP (within-task/IND prediction) will result in a strong CIL method. Since HAT and Sup are exemplar-free CL methods, HAT+CSI and Sup+CSI also do not need to save any previous task data. Tab. 3 shows that HAT and Sup equipped with CSI outperform the baselines by large margins. DER++, the best replay method, achieves 66.0 and 53.7 on C10-5T and C100-10T, respectively, while HAT+CSI achieves 87.8 and 63.3 and Sup+CSI achieves 86.0 and 65.1. The large performance gap remains consistent in more challenging problems, T-5T and T-10T. We note that Sup works very poorly on M-5T, but Sup+CSI improved it drastically, although still very weak compared to HAT+CSI.
Due to the definition of OOD in the prediction method and the fact that each task is trained separately in HAT and Sup, the outputs from different tasks can be in different scales, which will result in incorrect predictions. To deal with the problem, we can calibrate the output as and use . The optimal and for each task can be found by optimization with a memory buffer to save a very small number of training examples from previous tasks like that in the replay-based methods. We refer the calibrated methods as HAT+CSI+c and Sup+CSI+c. They are trained by using the memory buffer of the same size as the replay methods (see Sec. 4.2). Tab. 3 shows that the calibration improves from their memory free versions, i.e., without calibration. We provide the details about how to train the calibration parameters and in Appendix E.
As shown in Theorem 1, the CIL performance also depends on the TIL (WP) performance. We compare the TIL accuracies of the baselines and our methods in Tab. 4. Our systems again outperform the baselines by large margins on more challenging datasets (e.g., CIFAR100 and Tiny-ImageNet).
Method | M-5T | C10-5T | C100-10T | C100-20T | T-5T | T-10T |
---|---|---|---|---|---|---|
DER++ | 99.7
0.08 |
92.0
0.54 |
84.0
9.43 |
86.6
9.44 |
57.4
1.31 |
60.0
0.74 |
HAT | 99.9
0.02 |
96.7
0.18 |
84.0
0.23 |
85.0
0.98 |
61.2
0.72 |
63.8
0.41 |
Sup | 99.6
0.01 |
96.6
0.21 |
87.9
0.27 |
91.6
0.15 |
64.3
0.24 |
68.4
0.22 |
HAT+CSI | 99.9
0.00 |
98.7
0.06 |
92.0
0.37 |
94.3
0.06 |
68.4
0.16 |
72.4
0.21 |
Sup+CSI | 99.0
0.08 |
98.7
0.07 |
93.0
0.13 |
95.3
0.20 |
65.9
0.25 |
74.1
0.28 |
4.5 Implications for Existing CL Methods, Open-World Learning and Future Research
Implication for regularization and replay methods. Regularization-based (exemplar-free) methods try to protect important parameters of old tasks to mitigate CF. However, since the training of each task does not consider OOD detection, TP will be weak, which causes difficulty for inter-task class separation (ICS) and thus low CL accuracy. Replay-based methods are better as the replay data from old tasks can be naturally regarded as OOD data for the current task, then a better OOD model is built, which improves TP. However, since the replay data is small. the OOD model is sub-optimal, especially for earlier tasks as their training cannot see any future task data. Thus for both approaches, it will be beneficial to consider CF and OOD together in learning each task (e.g., kim2022multi).
Implication for open-world learning. Since our theory says that CL needs OOD detection, and OOD detection is also the first step in open-world learning (OWL), CL and OWL naturally work together to achieve self-motivated open-world continual learning liu2021self for autonomous learning or AI autonomy. That is, the AI agent can continually discover new tasks (OOD detection) and incrementally learn the tasks (CL) all on its own with no involvement of human engineers. Further afield, this work is also related to curiosity-driven self-supervised learning in reinforcement learning and 3D navigation.
Limitation and future work. The proposed theory provides a principled guidance on what needs to be done in order to achieve good CIL results, but it gives no guidance on how to do it. Although two example techniques are presented and evaluated, they are empirical. There are many options to define WP and TP (or OOD). An idea in (guo2022online) may be helpful in this regard. (guo2022online) argues that a continual learner should learn holistic feature representations of the input data, meaning to learn as many features as possible from the input data. The rationale is that if the system can learn all possible features from each task, then a future task does not have to learn those shared/intersecting features by modifying the parameters, which will result in less CF and also better ICS. A full representation of the IND data also improves OOD detection because the OOD score of a data point is basically the distance between the data point and the IND distribution. Only capturing a subset of features (e.g., by cross entropy) will result in poor OOD detection hu2020hrn because those missing features may be necessary to separate IND and some OOD data. In our future work, we will study how to optimize WP and TP/OOD and find the necessary conditions for them to do well.
5 Conclusion
This paper proposed a theoretical study on how to solve the highly challenging continual learning (CL) problem. class incremental learning (CIL) (the other popular CL setting is task incremental learning (TIL)). The theoretical result provides a principled guidance for designing better CIL algorithms. The paper first decomposed the CIL prediction into within-task prediction (WP) and task-id prediction (TP). WP is basically TIL. The paper further theoretically demonstrated that TP is correlated with out-of-distribution (OOD) detection. It then proved that a good performance of the two is both necessary and sufficient for good CIL performances. Based on the theoretical result, several new CIL methods were designed. They outperform strong baselines in CIL and also in TIL by a large margin. Finally, we also discussed the implications for existing CL techniques and open-world learning.
Acknowledgments
The work of Gyuhak Kim, Zixuan Ke and Bing Liu was supported in part by a research contract from KDDI, two NSF grants (IIS-1910424 and IIS-1838770), and a DARPA contract HR001120C0023.
References
Appendix A Proof of Theorems and Corollaries
a.1 Proof of Theorem 1
Proof.
Since
and
we have
∎
a.2 Proof of Corollary 1.
Proof.
a.3 Proof of Theorem 2.
Proof.
i) Assume .
For , we have
For , we have
ii) Assume .
For , by , we have
which means
For , by , we have
which means
Therefore, we have
Hence,
∎
a.4 Proof of Theorem 3.
a.5 Proof of Theorem 4.
Appendix B Additional Results and Explanation Regarding Table 1 in the Main Paper
In Sec. 4.3, we showed that a better OOD detection improves CIL performance. For the post-processing method ODIN, we only reported the results on C100-10T due to space limitations. Tab. 5 shows the results on the other datasets.
M-5T | C10-5T | C100-20T | T-5T | T-10T | |||||||
---|---|---|---|---|---|---|---|---|---|---|---|
Method | OOD | AUC | CIL | AUC | CIL | AUC | CIL | AUC | CIL | AUC | CIL |
OWM | Original | 99.13 | 95.81 | 81.33 | 51.79 | 71.90 | 24.15 | 58.49 | 10.00 | 59.48 | 8.57 |
ODIN | 98.86 | 95.16 | 71.72 | 40.65 | 68.52 | 23.05 | 58.46 | 10.77 | 59.38 | 9.52 | |
MUC | Original | 92.27 | 74.90 | 79.49 | 52.85 | 66.20 | 14.19 | 68.42 | 33.57 | 62.63 | 17.39 |
ODIN | 92.67 | 75.71 | 79.54 | 53.22 | 65.72 | 14.11 | 68.32 | 33.45 | 62.17 | 17.27 | |
PASS | Original | 98.74 | 76.58 | 66.51 | 47.34 | 70.26 | 24.99 | 65.18 | 28.40 | 63.27 | 19.07 |
ODIN | 90.40 | 74.33 | 63.08 | 35.20 | 69.81 | 21.83 | 65.93 | 29.03 | 62.73 | 17.78 | |
LwF | Original | 99.19 | 85.46 | 89.39 | 54.67 | 89.84 | 44.33 | 78.20 | 32.17 | 79.43 | 24.28 |
ODIN | 98.52 | 90.39 | 88.94 | 63.04 | 88.68 | 47.56 | 76.83 | 36.20 | 77.02 | 28.29 | |
BiC | Original | 99.40 | 94.11 | 90.89 | 61.41 | 89.46 | 48.92 | 80.17 | 41.75 | 80.37 | 33.77 |
ODIN | 98.57 | 95.14 | 91.86 | 64.29 | 87.89 | 47.40 | 74.54 | 37.40 | 76.27 | 29.06 | |
DER++ | Original | 99.78 | 95.29 | 90.16 | 66.04 | 85.44 | 46.59 | 71.80 | 35.80 | 72.41 | 30.49 |
ODIN | 99.09 | 94.96 | 87.08 | 63.07 | 87.72 | 49.26 | 73.92 | 37.87 | 72.91 | 32.52 | |
HAT | Original | 94.46 | 81.86 | 82.47 | 62.67 | 75.35 | 25.64 | 72.28 | 38.46 | 71.82 | 29.78 |
ODIN | 94.56 | 82.06 | 82.45 | 62.60 | 75.36 | 25.84 | 72.31 | 38.61 | 71.83 | 30.01 | |
HyperNet | Original | 85.83 | 56.55 | 78.54 | 53.40 | 72.04 | 18.67 | 54.58 | 7.91 | 55.37 | 5.32 |
ODIN | 86.89 | 64.31 | 79.39 | 56.72 | 73.89 | 23.8 | 54.60 | 8.64 | 55.53 | 6.91 | |
Sup | Original | 90.70 | 70.06 | 79.16 | 62.37 | 81.14 | 34.70 | 74.13 | 41.82 | 74.59 | 36.46 |
ODIN | 90.68 | 69.70 | 82.38 | 62.63 | 81.48 | 36.35 | 73.96 | 41.10 | 74.61 | 36.46 |
A continual learning method with a better AUC shows a better CIL performance than other methods with lower AUC. For instance, original HAT achieves AUC of 82.47 while HyperNet achieves 78.54 on C10-5T. The CIL for HAT is 62.67 while it is 53.40 for HyperNet. However, there are some exceptions that this comparison does not hold. An example is LwF. Its AUC and CIL are 89.39 and 54.67 on C10-5T. Although its AUC is better than HAT, the CIL is lower. This is due to the fact that CIL improves with WP and TP according to Theorem 1. The contraposition of Theorem 4 also says if the cross-entropy of TIL is large, that of CIL is also large. Indeed, the average within-task prediction (WP) accuracy for LwF on C10-5T is 95.2 while the same for HAT is 96.7. Improving WP is also important in achieving good CIL performances.
For PASS, we had to tune using a validation set. This is because the softmax in Eq. 10 improves AUC by making the IND (in-distribution) and OOD scores more separable within a task, but deteriorates the final scores across tasks. To be specific, the test instances are predicted as one of the classes in the first task after softmax because the relative values between classes in task 1 is larger than the other tasks in PASS. Therefore, larger and smaller , for , are chosen to compensate the relative values.
Appendix C Definitions of TP
As noted in the main paper, the class prediction in Eq. 2 varies by definition of WP and TP. The precise definition of WP and TP depends on implementation. Due to this subjectivity, we follow the prediction method as the existing methods in continual learning, which is the over the output. In this section, we show that the over output is a special case of Eq. 2. We also provide CIL results using different definitions of TP.
We first establish another theorem. This is an extension of Theorem 2 and connects the standard prediction method to our analysis.
Theorem 5 (Extension of Theorem 2).
i) If , let , , then .
ii) If , let , , then , where is an indicator function.
In Theorem 5 (proof appears later), we can observe that decreases with the increase of , while increases. Hence, when TP is given, let , we can find the optimal to define OOD by solving . Similarly, given OOD, let , we can find the optimal to define TP by finding the global minima of . The optimal can be found using a memory buffer to save a small number of previous data like that in a replay-based continual learning method.
In Theorem 5 (ii), let