SOLA: Continual Learning with Second-Order Loss Approximation

by Dong Yin, et al.

Neural networks have achieved remarkable success in many cognitive tasks. However, when they are trained sequentially on multiple tasks without access to old data, their performance on old tasks tends to drop significantly after the model is trained on new tasks. Continual learning aims to tackle this problem, often referred to as catastrophic forgetting, and to ensure sequential learning capability. We study continual learning from the perspective of loss landscapes and propose to construct a second-order Taylor approximation of the loss functions in previous tasks. Our proposed method does not require any memorization of raw data or their gradients, and therefore offers better privacy protection. We theoretically analyze our algorithm from an optimization viewpoint and provide a sufficient and worst-case necessary condition for the gradient updates on the approximate loss function to be descent directions for the true loss function. Experiments on multiple continual learning benchmarks suggest that our method is effective in avoiding catastrophic forgetting and, in many scenarios, outperforms several baseline algorithms that do not explicitly store data samples.





1 Introduction

Neural networks are achieving human-level performance on many cognitive tasks, including image classification krizhevsky2012imagenet and speech recognition hinton2006fast . However, as opposed to humans, their acquired knowledge is comparatively volatile and can be easily lost. In particular, the catastrophic forgetting phenomenon refers to the case where a neural network forgets past tasks if it is not allowed to retrain or iterate on them again goodfellow2013empirical ; mccloskey1989catastrophic .

Continual learning is a research direction that aims to solve the catastrophic forgetting problem. Recent works have tried to tackle this issue from a variety of perspectives. Regularization methods (e.g., kirkpatrick2017overcoming ; zenke2017continual ) aim to consolidate the weights that are important to previous tasks, while expansion based methods (e.g., rusu2016progressive ; yoon2018lifelong ) typically increase the model capacity to cope with new tasks. Repetition based methods (e.g., lopez2017gradient ; chaudhry2018efficient ) usually do not require additional complex modules; however, they have to maintain a small memory of previous data and use it to preserve knowledge. Unfortunately, the performance boost of repetition based methods comes at the cost of storing previous data, which may be undesirable whenever privacy is important. To address this issue, the authors of farajtabar2019orthogonal proposed to use the gradients of the previous data to constrain the weight updates; however, this may still be subject to privacy issues, as the gradient associated with each individual data point may disclose information about the raw data.

In this paper, we study the continual learning problem from the perspective of loss landscapes. We explicitly target minimizing an average over all tasks’ loss functions. The proposed method stores neither the data samples nor the individual gradients on the previous tasks. Instead, we propose to construct an approximation to the loss surface of previous tasks. More specifically, we approximate the loss function by estimating its second-order Taylor expansion. The approximation is used as a surrogate added to the loss function of the current task. Our method only stores information based on the statistics of the entire training dataset, such as full gradient and full Hessian matrix (or its low rank approximation), and thus better protects privacy. In addition, since we do not expand the model capacity, the neural network structure is less complex than that of expansion based methods.

We study our algorithm from an optimization perspective, and make the following theoretical contributions:

  • We prove a sufficient and worst-case necessary condition under which, by conducting gradient descent on the approximate loss function, we can still minimize the actual loss function.

  • We further provide convergence analysis of our algorithm for both non-convex and convex loss functions. Our results imply that early stopping can be helpful in continual learning.

  • We make connections between our method and elastic weight consolidation (EWC) kirkpatrick2017overcoming .

In addition, we make the following experimental contributions:

  • We conduct a comprehensive comparison among our algorithm and several baseline algorithms kirkpatrick2017overcoming ; chaudhry2018efficient ; farajtabar2019orthogonal on a variety of combinations of datasets and models. We observe that in many scenarios, especially when the learner is not allowed to store the raw data samples, our proposed algorithm outperforms them. We also discuss the conditions under which the proposed method or any of the alternatives are effective.

  • We provide experimental evidence validating the importance of accurate approximation of the Hessian matrix and discuss scenarios in which early stopping is helpful for our algorithm.

2 Related work

Avoiding catastrophic forgetting in continual learning Parisi2018ContinualLL ; beaulieu2020learning is an important milestone towards achieving artificial general intelligence (AGI), which entails developing measurements toneva2018empirical ; kemker2018measuring , evaluation protocols farquhar2018towards ; de2019continual , and theoretical understanding nguyen2019toward ; farquhar2019unifying of the phenomenon. Generally speaking, three classes of algorithms exist to overcome catastrophic forgetting farajtabar2019orthogonal .

The expansion based methods allocate new neurons, layers, or modules to accommodate new tasks while utilizing the shared representation learned from previous ones rusu2016progressive ; xiao2014error ; yoon2018lifelong ; li2019learn ; Jerfel2018ReconcilingMA . Although this is a very natural approach, the mechanism of dynamic expansion can be quite complex and can add considerable overhead to the training process.

The repetition and memory based methods store previous data or, alternatively, train a generative model of them, and replay these samples interleaved with samples drawn from the current task shin2017continual ; kamra2017deep ; zhang2019prototype ; rios2018closed ; luders2016continual ; lopez2017gradient ; farajtabar2019orthogonal . They achieve promising performance, but at the cost of a higher risk to users' privacy, since they store the data or learn a generative model of it.

The regularization based approaches impose limiting constraints on the weight updates of the neural network according to some relevance score for previous knowledge kirkpatrick2017overcoming ; nguyen2017variational ; titsias2019functional ; ritter2018online ; mirzadeh2020dropout ; zenke2017continual ; park2019continual . These methods provide a better privacy guarantee, as they do not explicitly store the data samples. In general, SOLA also belongs to this category, as we use the second-order Taylor expansion as the regularization term in new tasks. Many of the regularization methods are derived from a Bayesian perspective of estimating the posterior distribution of the model parameters given the data from a sequence of tasks kirkpatrick2017overcoming ; nguyen2017variational ; titsias2019functional ; ritter2018online ; some of these methods use other heuristics to either estimate the importance of the weights of the neural network zenke2017continual ; park2019continual or implicitly limit the capacity of the network mirzadeh2020dropout . Similar to our approach, several regularization based methods use quadratic functions as the regularization term, and many of them use the diagonal form of quadratic functions kirkpatrick2017overcoming ; zenke2017continual ; park2019continual . In Section 5.3, we demonstrate that in some cases, the EWC algorithm kirkpatrick2017overcoming can be considered as the diagonal approximation of our approach. Here, we note that the diagonal form of quadratic regularization has the drawback that it does not take the interactions between the weights into account.

Among the regularization based methods, the online Laplace approximation algorithm ritter2018online is the most similar to our proposed method. Despite the similarity in the implementations, the two algorithms are derived from very different perspectives: the online Laplace approximation algorithm uses a Bayesian approach that approximates the posterior distribution of the weights with a Gaussian distribution, whereas our algorithm is derived from an optimization viewpoint using Taylor approximation of loss functions. More importantly, the Gaussian approximation in ritter2018online is proposed as a heuristic, whereas in this paper, we provide rigorous theoretical analysis of how the approximation error affects the optimization procedure. We believe that our analysis provides deeper insights into the loss landscape of continual learning problems, and explains some important implementation details such as early stopping.

We also note that continual learning is broader than just solving catastrophic forgetting, and is connected to many other areas such as meta learning riemer2018learning , few-shot learning wen2018few ; gidaris2018dynamic , and learning without explicit task identifiers rao2019continual ; aljundi2019online , to name a few.

3 Problem formulation

We consider a sequence of supervised learning tasks T_t, t ∈ [T].¹ For task T_t, there is an unknown distribution D_t over the space of feature-label pairs X × Y. Let W be the model parameter space,² and for the t-th task, let ℓ_t(w; x, y) be the loss function of w ∈ W associated with data point (x, y). The population loss function of task T_t is defined as L_t(w) := E_{(x,y)∼D_t}[ℓ_t(w; x, y)]. Our general objective is to learn a parametric model with minimized population loss over all the tasks. More specifically, in continual learning, the learner follows the following protocol: when learning the t-th task, the learner obtains access to n_t data points (x_i, y_i), i ∈ [n_t], sampled i.i.d. according to D_t, and we define the empirical loss function L̂_t(w) := (1/n_t) Σ_{i=1}^{n_t} ℓ_t(w; x_i, y_i); the learner then updates the model parameter using these training data, and after the training procedure is finished, the learner loses access to the training data, but can store some side information about the task. Our goal is to avoid forgetting previous tasks when training on new tasks by utilizing this side information. We provide details of our algorithm design in the next section.

¹For any positive integer n, we define [n] := {1, …, n}.
²In most cases, we consider W = R^d.

4 Our approach

To measure the effectiveness of a continual learning algorithm, we use a simple criterion: after each task, we hope the average population loss over all the tasks that have been trained on to be small, i.e., for every t, after training on T_t, we hope to solve min_{w ∈ W} (1/t) Σ_{k=1}^{t} L_k(w). Since minimizing the loss function is the key to training a good model, we propose a straightforward method for continual learning: storing the second-order Taylor expansion of the empirical loss function, and using it as a surrogate of the loss function for an old task when training on new tasks. We start with a simple setting. Suppose that there are two tasks, and at the end of T_1, we obtain a model w_1. Then we compute the gradient g_1 := ∇L̂_1(w_1) and Hessian matrix H_1 := ∇²L̂_1(w_1) of L̂_1 at w_1, and construct the second-order Taylor expansion of L̂_1 at w_1:

L̃_1(w) := L̂_1(w_1) + g_1ᵀ(w − w_1) + ½ (w − w_1)ᵀ H_1 (w − w_1).

When training on T_2, we try to minimize L̃_1(w) + L̂_2(w). The basic idea of this design is that we hope, in a neighborhood around w_1, the quadratic function L̃_1 stays a good approximation of L̂_1, so that approximately we still minimize the average of the empirical loss functions, which in the limit generalizes to the average of the population loss functions.

We rely on the assumption that the second-order Taylor approximations of the loss functions can capture their local geometry well. For a general nonlinear function and an arbitrary displacement, this approximation can be over-simplistic; however, we refer to the abundance of observations that the loss surfaces of modern neural networks are well-behaved, with flat and wide minima choromanska2015loss ; goodfellow2014qualitatively . Moreover, the assumption of a well-behaved loss around tasks' local minima also forms the basis of a few other continual learning algorithms such as EWC kirkpatrick2017overcoming and OGD farajtabar2019orthogonal .
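As a concrete illustration of the surrogate idea (a toy sketch of our own, not the paper's implementation), the snippet below builds the second-order Taylor expansion of a hypothetical one-parameter loss at an end-of-task point and checks that it tracks the true loss nearby:

```python
# Toy sketch (our illustration): second-order Taylor surrogate of a
# hypothetical one-parameter loss at the point reached after task 1.

def loss(w):                       # hypothetical previous-task loss
    return (w - 1.0) ** 4 + 0.5 * (w - 1.0) ** 2

def grad(w):
    return 4.0 * (w - 1.0) ** 3 + (w - 1.0)

def hess(w):
    return 12.0 * (w - 1.0) ** 2 + 1.0

w_star = 1.2                       # parameters obtained at the end of task 1
g, h, c = grad(w_star), hess(w_star), loss(w_star)

def surrogate(w):
    # L_tilde(w) = L(w*) + g * (w - w*) + 0.5 * h * (w - w*)^2
    d = w - w_star
    return c + g * d + 0.5 * h * d * d

# Near w_star the surrogate is a good stand-in for the true loss, so adding
# it to a new task's loss approximately preserves the old task's objective.
print(loss(1.25), surrogate(1.25))
```

Far from w_star the two curves diverge, which is exactly the effect the theoretical analysis in Section 5 quantifies.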

Formally, let w_t be the model that we obtain at the end of the t-th task. We define the approximation of the sum of the first t empirical loss functions as

L̃_t(w) := Σ_{k=1}^{t} [ g_kᵀ(w − w_k) + ½ (w − w_k)ᵀ H̃_k (w − w_k) ] + c_t,   (1)

where H̃_k denotes the Hessian matrix ∇²L̂_k(w_k) or its low rank approximation, g_k := ∇L̂_k(w_k), k ∈ [t], and c_t is a constant that does not depend on w. We construct L̃_t at the end of task T_t, and when training on task T_{t+1}, we minimize L̃_t(w) + L̂_{t+1}(w). In the following, we name our algorithm SOLA, an acronym for Second-Order Loss Approximation.

As we can see, in the SOLA algorithm, after each task, if we choose to use the exact Hessian matrix, i.e., H̃_k = ∇²L̂_k(w_k), k ∈ [t], it suffices to update the accumulated gradient and Hessian terms in memory, and thus the memory cost of the algorithm is O(d²), which does not grow with the number of tasks. However, in practice, especially for overparameterized neural networks, the dimension d of the model is usually large, and thus the storage cost of memorizing the Hessian matrix can be high. Recent studies have shown that the Hessian matrices of the loss functions of deep neural networks are usually approximately low rank ghorbani2019investigation . If we choose H̃_k as a rank-r approximation of ∇²L̂_k(w_k), we need to keep accumulating the low rank approximations of the Hessian matrices in order to construct L̃_t, and at the end of task T_t, the memory cost is O(trd), which in practice can be much smaller than that of using the exact Hessian matrices. We formally present our approach in Algorithm 1; the methods that use exact Hessian matrices and their low rank approximations are presented as options I and II, respectively. Moreover, we can use a recursive implementation for the low rank approximation, with which the memory cost can be further reduced to O(rd), which does not grow with t. We present the details of the recursive implementation in Section 6.

1:  Input: initial weights w_0, learning rate η, the number of tasks T, the rank r of the Hessian approximation (for option II)
2:  for t = 1, …, T do
3:     access training data for the t-th task: (x_i, y_i), i ∈ [n_t]
4:     while termination condition not satisfied do
5:        compute the (stochastic) gradient g of the current loss L̂_t
6:        compute the gradient g̃ of the loss function approximation L̃_{t−1} (with L̃_0 ≡ 0)
7:        update w ← w − η (g + g̃)
8:     end while
9:     w_t ← w, and construct L̃_t according to (1), with
10:     H̃_t = ∇²L̂_t(w_t) (option I) or a rank-r approximation of ∇²L̂_t(w_t) (option II)
11:  end for
Algorithm 1 Continual learning with second-order loss approximation (SOLA)
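To make the overall loop concrete, here is a minimal self-contained sketch of SOLA with exact Hessians on two hypothetical 2-D quadratic tasks; the task definitions, step counts, and learning rate are our own illustrative choices, not the paper's experimental setup. Because the tasks are exactly quadratic, the surrogate is exact, and training on the second task minimizes the sum of both losses:

```python
# Minimal sketch of SOLA with exact Hessians on two toy 2-D quadratic tasks
# (our own illustration; all constants are arbitrary choices).

def make_task(A, b):
    # L(w) = 0.5 * (w - b)^T A (w - b); gradient A (w - b); Hessian A
    def grad(w):
        d = [w[0] - b[0], w[1] - b[1]]
        return [A[0][0] * d[0] + A[0][1] * d[1],
                A[1][0] * d[0] + A[1][1] * d[1]]
    def loss(w):
        g = grad(w)
        d = [w[0] - b[0], w[1] - b[1]]
        return 0.5 * (d[0] * g[0] + d[1] * g[1])
    return loss, grad, A

tasks = [make_task([[2.0, 0.0], [0.0, 1.0]], [1.0, 0.0]),
         make_task([[1.0, 0.0], [0.0, 2.0]], [0.0, 1.0])]

w = [0.0, 0.0]
eta = 0.1
anchor = w[:]                        # point of the last Taylor expansion
G = [0.0, 0.0]                       # accumulated surrogate gradient
H = [[0.0, 0.0], [0.0, 0.0]]         # accumulated Hessian

for loss, grad, hess in tasks:
    for _ in range(500):
        d = [w[0] - anchor[0], w[1] - anchor[1]]
        # gradient of the quadratic surrogate of all previous tasks
        sur = [G[0] + H[0][0] * d[0] + H[0][1] * d[1],
               G[1] + H[1][0] * d[0] + H[1][1] * d[1]]
        g = grad(w)
        w = [w[0] - eta * (g[0] + sur[0]), w[1] - eta * (g[1] + sur[1])]
    # end of task: fold the current task into the surrogate at the new anchor
    d = [w[0] - anchor[0], w[1] - anchor[1]]
    g = grad(w)
    G = [G[0] + H[0][0] * d[0] + H[0][1] * d[1] + g[0],
         G[1] + H[1][0] * d[0] + H[1][1] * d[1] + g[1]]
    H = [[H[i][j] + hess[i][j] for j in range(2)] for i in range(2)]
    anchor = w[:]

avg_loss = 0.5 * sum(loss(w) for loss, _, _ in tasks)
print(w, avg_loss)
```

For these two quadratics the final iterate approaches the joint minimizer (2/3, 2/3) with average loss 1/3, without ever revisiting the first task's data.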

5 Theoretical analysis

In this section, we provide a theoretical analysis of our algorithm. The key idea in our algorithm is to approximate the loss functions of previous tasks using quadratic functions. This leads to the following theoretical question: by running the gradient descent algorithm on an approximate loss function, can we still minimize the actual loss function that we are interested in?

For the purpose of theoretical analysis, we make a few simplifications to our setup. Without loss of generality, we study the training process of the last task T_T, and still use w_k to denote the model parameters that we obtain at the end of the k-th task. We use the loss function approximation in (1), but for simplicity we ignore the finite-sample effect and replace the empirical loss functions with the population loss functions, i.e., we define

L̃(w) := Σ_{k=1}^{T−1} [ g_kᵀ(w − w_k) + ½ (w − w_k)ᵀ H̃_k (w − w_k) ] + L_T(w),

where g_k := ∇L_k(w_k) and H̃_k represents ∇²L_k(w_k) or its low rank approximation. The reason for this simplification is that our focus is the optimization aspect of the problem, while the generalization aspect can be tackled by tools such as uniform convergence mohri2018foundations . As discussed, during the training of the last task, we have access to the approximate loss function L̃(w), whereas the actual loss function that we care about is L(w) := Σ_{k=1}^{T} L_k(w). We also focus on gradient descent instead of its stochastic counterpart. In particular, let w^(0) be the initial model parameter for the last task. We run the following update for s = 0, 1, 2, …:

w^(s+1) = w^(s) − η ∇L̃(w^(s)).   (3)
We use the following standard notions for a differentiable function f : R^d → R.

Definition 1.

f is β-smooth if ‖∇f(w) − ∇f(w′)‖ ≤ β‖w − w′‖ for all w, w′ ∈ R^d.

Definition 2.

f is ρ-Hessian Lipschitz if ‖∇²f(w) − ∇²f(w′)‖ ≤ ρ‖w − w′‖ for all w, w′ ∈ R^d.

We make the assumptions that the loss functions are smooth and Hessian Lipschitz. We note that the Hessian Lipschitz assumption is standard in the analysis of non-convex optimization nesterov2006cubic ; jin2017escape .

Assumption 1.

We assume that L_k is β-smooth and ρ-Hessian Lipschitz for all k ∈ [T].

We also assume that the error between the exact Hessian matrices and the matrices H̃_k is bounded.

Assumption 2.

We assume that for every k ∈ [T−1], ‖H̃_k‖ ≤ β, where β is defined in Assumption 1, and that ‖H̃_k − ∇²L_k(w_k)‖ ≤ δ for some δ ≥ 0.

5.1 Sufficient and worst-case necessary condition for one-step descent

We begin by analyzing a single step during training. Our goal is to understand whether, by running a single step of gradient descent on the approximate loss L̃, we can still decrease the actual loss function L. More specifically, we have the following result.

Theorem 1.

Under Assumptions 1 and 2, suppose that in the s-th iteration, the magnitude of the gradient of the approximate loss satisfies condition (4), i.e., it exceeds a threshold determined by the approximation error, and that the learning rate η satisfies a corresponding upper bound; then the true loss decreases in this iteration.

We prove Theorem 1 in Appendix A. Here, we emphasize that this result does not assume any convexity of the loss functions. The theorem provides a sufficient condition (4) under which, by running gradient descent on L̃, we can still minimize the true loss function L. Intuitively, this condition requires the gradient of L̃ to be large enough, such that the magnitude of the gradient is larger than the error caused by the inexactness of the loss function. In Proposition 1 below, we will see that this condition is also necessary in the worst-case scenario. More specifically, we can construct cases in which (4) is violated and the gradients of L̃ and L have opposite directions.

Proposition 1.

Suppose that , , . Then, there exists , , , and such that if , then .

We prove Proposition 1 in Appendix B. In addition, we note that Theorem 1 also implies that as training goes on and the gradient magnitude decreases, it is beneficial to decrease the learning rate η, since when the gradient magnitude decreases, the upper bound on η that guarantees the decay of the true loss also decreases. We note that the importance of learning rate decay for continual learning has also been observed in recent empirical studies mirzadeh2020dropout .

5.2 Convergence analysis

Although the condition in (4) provides us with insights on the dynamics of the training algorithm, it is usually hard to check this condition in every step, since we may not have good estimates of the relevant quantities. A practical implementation is to choose a constant learning rate along with an appropriate number of training steps. In this section, we provide bounds on the convergence behavior of our algorithm with a constant learning rate and a fixed number of iterations, both for non-convex and convex loss functions. These results imply that early stopping can be helpful, and provide a theoretical treatment of the very intuitive fact that the more iterations one optimizes for the current task, the more forgetting can happen for the previous ones. We begin with a convergence analysis for non-convex loss functions in Theorem 2, in which we use the common choice of learning rate η = 1/β for gradient descent on smooth functions bubeck2014convex .

Theorem 2 (non-convex).

Let , , , and . Then, under Assumptions 1 and 2, after running iterations of the gradient descent update (3) with learning rate , we have

where , , , and .

We prove Theorem 2 in Appendix C. Unlike in standard optimization analysis, the average norm of the gradients does not always decrease as the number of iterations increases when δ > 0 or ρ > 0. Intuitively, as we move far from the points where we conduct the Taylor expansions, the gradient of L̃ becomes more and more inaccurate, and thus we need to stop early. In Section 7, we provide supporting experimental evidence.

When the loss functions are convex, we can prove a better guarantee which does not contain some of the error terms of Theorem 2. More specifically, we have the following assumption and theorem.

Assumption 3.

We assume that L_k is convex for all k ∈ [T].

Theorem 3 (convex).

Suppose that Assumptions 12, and 3 hold, and define , , , and . After running iterations of the gradient descent update (3) with learning rate , we have

where , and .

We prove Theorem 3 in Appendix D. As we can see, if δ > 0 or ρ > 0, we still cannot guarantee convergence to the true minimum of L, due to the inexactness of L̃. On the other hand, if the loss functions are quadratic and we save the full Hessian matrices, i.e., ρ = 0 and δ = 0, then, since we have full information about the previous loss functions, we recover the standard convergence rate for gradient descent on convex and smooth functions.

5.3 Connection to EWC

The elastic weight consolidation (EWC) algorithm kirkpatrick2017overcoming for continual learning is proposed based on the Bayesian idea of estimating the posterior distribution of the model parameters. Interestingly, we notice that our algorithm has a connection with EWC, although their basic ideas are quite different. More specifically, we show that in some cases, the regularization term that the EWC algorithm uses can be considered as a diagonal approximation of the Hessian matrix of the loss function. Suppose that in the t-th task, the data points are sampled from a probabilistic model with likelihood function p(x, y; w), and we use the negative log-likelihood as the loss function, i.e., ℓ_t(w; x, y) = −log p(x, y; w). Suppose that at the end of this task, we obtain the ground truth model parameter w_t. Then we know that the expected gradient of the loss vanishes at w_t, and that the Fisher information of the i-th coordinate of w_t equals the expectation of the i-th diagonal entry of the Hessian of the loss at w_t, i.e., F_i = E[(∂ℓ_t/∂w_i)²] = E[∂²ℓ_t/∂w_i²]. The EWC algorithm constructs the regularization term Σ_i F_i (w_i − w_{t,i})² as a proxy of the loss function of the t-th task, and uses it in the following tasks. As we can see, in this case, the quadratic regularization in EWC is a diagonal approximation of the quadratic term in our loss function approximation approach.
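The difference between the two regularizers can be seen directly in a toy example (our own numbers, not from the paper): with the same Hessian and anchor point, an EWC-style penalty equals the SOLA quadratic term minus the off-diagonal cross terms.

```python
# Toy comparison (our own numbers): EWC's diagonal penalty vs the full SOLA
# quadratic term built from the same Hessian H at the same anchor w_star.

H = [[2.0, 0.9], [0.9, 1.0]]       # Hessian of the old task's loss at w_star
w_star = [1.0, -1.0]

def sola_penalty(w):
    # 0.5 * (w - w*)^T H (w - w*): keeps weight interactions (off-diagonals)
    d = [w[0] - w_star[0], w[1] - w_star[1]]
    return 0.5 * (d[0] * (H[0][0] * d[0] + H[0][1] * d[1]) +
                  d[1] * (H[1][0] * d[0] + H[1][1] * d[1]))

def ewc_penalty(w):
    # diagonal approximation: 0.5 * sum_i H_ii * (w_i - w*_i)^2
    d = [w[0] - w_star[0], w[1] - w_star[1]]
    return 0.5 * (H[0][0] * d[0] * d[0] + H[1][1] * d[1] * d[1])

w = [1.5, -0.5]
# The two penalties differ exactly by the cross term H[0][1] * d0 * d1.
print(sola_penalty(w), ewc_penalty(w))
```

When the off-diagonal curvature is large, the diagonal penalty can substantially under- or over-estimate the cost of a joint move of two weights, which is the drawback noted in Section 2.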

6 A recursive implementation

As we have seen, one drawback of the SOLA algorithm with low rank approximation in Section 4 is that the memory cost grows with the number of tasks. In this section, we present a more practical and memory efficient implementation of SOLA with low rank approximation. Recall that L̂_t is the empirical loss function for the t-th task, t ∈ [T]. We then define the loss function approximation in a recursive way. We begin with L̃_0 ≡ 0, and for every t ≥ 1, we define

L̃_t(w) := L̃_{t−1}(w_t) + L̂_t(w_t) + g̃_tᵀ(w − w_t) + ½ (w − w_t)ᵀ H̃_t (w − w_t),

where g̃_t := ∇(L̃_{t−1} + L̂_t)(w_t), and H̃_t is a rank-r approximation of the Hessian matrix ∇²(L̃_{t−1} + L̂_t)(w_t). This means that at the end of task T_t, we compute the second-order Taylor expansion of the approximate loss function L̃_{t−1} + L̂_t at w_t, with the Hessian matrix replaced by the low rank approximation H̃_t. Thus, we only need to store w_t, g̃_t, and H̃_t, and the memory cost is O(rd), which does not grow with t. We formally present this approach in Algorithm 2. In our experiments in Section 7, we use the recursive implementation for SOLA with low rank approximation.

1:  Input: initial weights w_0, learning rate η, the number of tasks T, rank r
2:  for t = 1, …, T do
3:     access training data for the t-th task: (x_i, y_i), i ∈ [n_t]
4:     while termination condition not satisfied do
5:        compute the (stochastic) gradient g of the current loss L̂_t
6:        compute the gradient g̃ of the loss function approximation L̃_{t−1} (with L̃_0 ≡ 0)
7:        update w ← w − η (g + g̃)
8:     end while
9:     w_t ← w, g̃_t ← ∇(L̃_{t−1} + L̂_t)(w_t), and
10:     H̃_t ← rank-r approximation of ∇²(L̃_{t−1} + L̂_t)(w_t)
11:  end for
Algorithm 2 Recursive implementation of SOLA with low rank approximation
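The rank-r approximation constructed at the end of each task can be computed in several ways. As a self-contained illustration (not the paper's TensorFlow implementation), the sketch below extracts a rank-k approximation of a small symmetric matrix by power iteration with deflation; all matrices and iteration counts are our own toy choices:

```python
# Sketch (our illustration): rank-k approximation of a small symmetric
# "Hessian" via power iteration with deflation.

def matvec(M, v):
    return [sum(M[i][j] * v[j] for j in range(len(v))) for i in range(len(M))]

def top_eigpair(M, iters=500):
    n = len(M)
    v = [1.0 / n ** 0.5] * n           # fixed, generic starting vector
    for _ in range(iters):
        w = matvec(M, v)
        norm = sum(x * x for x in w) ** 0.5
        v = [x / norm for x in w]
    w = matvec(M, v)
    lam = sum(v[i] * w[i] for i in range(n))   # Rayleigh quotient
    return lam, v

def rank_k_approx(M, k):
    n = len(M)
    R = [row[:] for row in M]          # residual matrix, deflated in place
    approx = [[0.0] * n for _ in range(n)]
    for _ in range(k):
        lam, v = top_eigpair(R)
        for i in range(n):
            for j in range(n):
                approx[i][j] += lam * v[i] * v[j]
                R[i][j] -= lam * v[i] * v[j]
    return approx

H = [[3.0, 1.0], [1.0, 2.0]]
print(rank_k_approx(H, 1))             # dominant curvature direction only
print(rank_k_approx(H, 2))             # with k = n, recovers H exactly
```

For positive semidefinite Hessians near a minimum, this keeps the directions of largest curvature; in practice, the matrix is never materialized, and the eigensolver works through Hessian-vector products (Section 7).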

7 Experiments

We implement the experiments with TensorFlow. When computing the exact Hessian matrix or its low rank approximation, we treat each tensor in the model independently; in other words, we compute a block diagonal approximation of the Hessian matrix. This technique has the benefit that the Hessian computation is independent of the model architecture, and it has been used in recent studies on second-order optimization gupta2018shampoo . We use the recursive implementation for SOLA with low rank approximation. In the following, we denote SOLA with the exact Hessian matrix and with the low rank approximation by SOLA-exact and SOLA-prox, respectively. As for the calculation of the low rank matrix, we make use of Hessian-vector products and provide details in the appendix.
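Frameworks such as TensorFlow compute Hessian-vector products exactly via double backpropagation. As a framework-free illustration of the underlying identity, the toy sketch below approximates Hv by a central finite difference of a hand-coded gradient; the loss function and evaluation point are our own choices:

```python
# Framework-free sketch of a Hessian-vector product:
# H v ~= (grad(w + eps v) - grad(w - eps v)) / (2 eps)

EPS = 1e-5

def grad(w):
    # gradient of the toy loss L(w) = w0^2 * w1 + w1^2
    return [2.0 * w[0] * w[1], w[0] * w[0] + 2.0 * w[1]]

def hvp(w, v):
    wp = [w[i] + EPS * v[i] for i in range(len(w))]
    wm = [w[i] - EPS * v[i] for i in range(len(w))]
    gp, gm = grad(wp), grad(wm)
    return [(gp[i] - gm[i]) / (2.0 * EPS) for i in range(len(w))]

# At w = (1, 2) the Hessian is [[2*w1, 2*w0], [2*w0, 2]] = [[4, 2], [2, 2]];
# multiplying by basis vectors recovers its columns, which is all an
# iterative eigensolver (power iteration, Lanczos) needs to run matrix-free.
print(hvp([1.0, 2.0], [1.0, 0.0]))
print(hvp([1.0, 2.0], [0.0, 1.0]))
```

The key point is that each product costs only two gradient evaluations, so low rank factors can be extracted without ever forming the d × d Hessian.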
Datasets. We use multiple standard continual learning benchmarks created based on the MNIST lecun1998gradient and CIFAR-10 krizhevsky2009learning datasets, i.e., Permuted MNIST goodfellow2013empirical , Rotated MNIST lopez2017gradient , Split MNIST zenke2017continual , and Split CIFAR (similar to a dataset in chaudhry2018efficient ). In Permuted MNIST, for each task, we choose a random permutation of the pixels of the MNIST images, and reorder all the images according to the permutation; we use several Permuted MNIST tasks in the experiments. In Rotated MNIST, for each task, we rotate the MNIST images by a particular angle; in our experiments, we use a 5-task Rotated MNIST with a different rotation angle for each task. For Split MNIST, we split the labels of the MNIST dataset into disjoint subsets, and for each task, we use the MNIST data whose labels belong to a particular subset. In this paper, we use a 5-task Split MNIST, with the subsets of labels being {0, 1}, {2, 3}, {4, 5}, {6, 7}, and {8, 9}. Split CIFAR is defined similarly to Split MNIST, and we use a 2-task Split CIFAR.
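A Permuted MNIST task can be generated as sketched below (our schematic illustration with a 4-pixel "image"; real MNIST images have 784 pixels):

```python
# Sketch of how a Permuted MNIST task is generated (our illustration): each
# task fixes one random pixel permutation and applies it to every image.

import random

def make_permutation(num_pixels, seed):
    idx = list(range(num_pixels))
    random.Random(seed).shuffle(idx)   # same seed -> same task's permutation
    return idx

def permute_image(flat_image, perm):
    return [flat_image[p] for p in perm]

perm = make_permutation(4, seed=0)     # 784 for real 28x28 MNIST images
img = [10, 20, 30, 40]
print(permute_image(img, perm))        # pixels reordered, values preserved
```

Since each task scrambles the pixels with a fresh permutation, consecutive tasks share no spatial structure, which is what makes Permuted MNIST a harsh forgetting benchmark.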


Models. We use both multilayer perceptrons (MLP) and convolutional neural networks (CNN). In most cases, we use an MLP with two hidden layers, sometimes denoted by MLP-[n1, n2], with n1 and n2 being the numbers of hidden units. We use CNN-c to denote a CNN model with c convolutional layers, and provide details of the models in Appendix F. For Split MNIST and Split CIFAR, we use MLP and CNN models with a multi-head structure similar to what has been used in chaudhry2018efficient ; farajtabar2019orthogonal . In the multi-head model, instead of having a single set of logits in the output layer, we use separate heads for different tasks, and each head corresponds to the classes of the associated task. During training, for each task, we only optimize the cross-entropy loss over the logits and labels of the corresponding output head.

Baselines. We compare the SOLA algorithm with several baselines: the vanilla algorithm, which runs SGD over all the tasks without storing any side information; the multi-task algorithm, which assumes access to all the training data of previous tasks; the repetition based A-GEM algorithm chaudhry2018efficient , which stores a subset of data samples from the previous tasks and forms a constrained optimization problem when training on new tasks; the regularization based EWC algorithm kirkpatrick2017overcoming discussed in Section 5.3; and the orthogonal gradient descent (OGD) algorithm farajtabar2019orthogonal , which stores the gradients from previous tasks and forms a constrained optimization problem.³ Among these, our algorithm, along with the vanilla, EWC, and OGD algorithms, does not explicitly store the raw data samples. Following prior works chaudhry2018efficient ; farajtabar2019orthogonal ; kirkpatrick2017overcoming , we fix the learning rate and the batch size. For all the results that we report, we present the average over independent runs, as well as the standard deviation (shown as the shaded areas in the figures).

³Comparison between OGD and SOLA-prox: since OGD has a memory cost that grows linearly with the number of tasks but SOLA-prox does not, we keep their average memory costs the same.

Model size | [10,10]  | [100,100] | [10,10]  | [100,100] | 4-conv   | [10,10]  | [100,100]
Multi-task | 91.8±0.4 | 97.0±0.1  | 91.4±0.4 | 97.5±0.1  | 98.8±0.1 | 98.9±0.3 | 99.3±0.1
A-GEM      | 84.1±1.1 | 93.2±0.4  | 83.6±1.0 | 92.6±0.4  | 95.3±0.3 | 91.2±4.9 | 97.8±0.4
Vanilla    | 69.2±3.1 | 81.1±1.6  | 76.8±0.9 | 86.0±0.5  | 89.5±0.6 | 86.4±6.6 | 97.2±0.9
EWC        | 69.1±3.7 | 80.2±1.4  | 76.9±1.0 | 86.1±0.6  | 89.4±0.7 | 87.7±9.2 | 97.7±0.8
OGD        | 68.9±3.3 | 81.5±1.7  | 81.1±1.3 | 88.0±0.7  | 89.5±0.7 | 97.1±1.8 | 98.8±0.1
SOLA-exact | 90.0±0.9 | -         | 88.6±0.9 | -         | -        | 96.3±3.0 | -
SOLA-prox  | 86.2±1.5 | 87.8±0.6  | 86.5±0.9 | 90.4±0.5  | 92.2±1.5 | 96.1±2.5 | 99.0±0.2
Table 1: Average test accuracy (%) ± std. P-MNIST, R-MNIST, and S-MNIST represent the Permuted, Rotated, and Split MNIST datasets, respectively: the first two columns report P-MNIST, the next three R-MNIST, and the last two S-MNIST. Boldface numbers correspond to the best result among algorithms that do not store raw data points, i.e., excluding multi-task and A-GEM. The number of stored gradient samples per task for OGD and the rank r for SOLA-prox differ between the MLP models and the other models.
Model | Multi-task | A-GEM    | Vanilla  | EWC      | OGD      | SOLA-exact | SOLA-prox
CNN-2 | 75.9±0.9   | 65.8±2.1 | 57.2±4.2 | 55.6±4.6 | 56.5±4.2 | 62.0±5.4   | 59.4±3.8
CNN-6 | 78.6±1.4   | 68.1±2.3 | 57.5±4.6 | 57.7±3.8 | 58.3±4.8 | -          | 58.6±5.2
MLP   | 69.2±0.5   | 66.1±0.7 | 63.5±1.6 | 63.8±2.1 | 65.8±1.2 | -          | 55.7±3.2
Table 2: Average test accuracy (%) ± std on Split CIFAR. Boldface numbers correspond to the best result among algorithms that do not store raw data, i.e., excluding multi-task and A-GEM. For OGD we store a fixed number of gradient samples for each task, and for SOLA-prox we choose a fixed rank r.

Results. We provide a comprehensive comparison among SOLA and the baseline algorithms on a variety of combinations of datasets and models. Tables 1 and 2 present the results for the MNIST-based datasets and Split CIFAR, respectively. We make a few notes before discussing the results. First, the multi-task algorithm uses all the data of previous tasks, and therefore serves as an upper bound for the performance of continual learning algorithms. Second, since the A-GEM algorithm stores a subset of data samples from previous tasks, it is not completely fair to compare A-GEM with algorithms that do not store raw data; however, we still report the results for A-GEM for reference, storing a fixed number of data points for each task. Third, since the performance of the algorithms depends on the number of epochs that we train for each task, we treat this quantity as a tuning parameter, and for each algorithm, we report the result corresponding to the best epoch choice for its performance, chosen from a fixed candidate set for the MNIST-based datasets and for Split CIFAR. Due to memory constraints, we only implement SOLA-exact on small models such as MLP and CNN-2. We conclude from the results as follows:

  • If it is allowed to store raw data, repetition based algorithms such as A-GEM should be the method of choice. This highlights the importance of the information contained in the raw data samples. In some cases we observe that SOLA outperforms A-GEM, e.g., on the MLP models; however, we expect that the performance of A-GEM can be improved if more data are stored in memory.

  • If it is not allowed to store raw data due to privacy concerns, then in many scenarios, SOLA outperforms the baseline algorithms. In particular, on the MNIST-based datasets, SOLA-exact or SOLA-prox achieves the best performance in 6 out of the 7 settings.

  • On Split CIFAR, we observe mixed results. When the model is relatively small (CNN-2) and we can store the exact Hessian matrix, SOLA-exact achieves the best performance. On a relatively large CNN model (CNN-6), we observe that none of the continual learning algorithms (EWC, OGD, SOLA) significantly outperforms the vanilla algorithm. On a large MLP, we observe that OGD performs the best, and the result for SOLA-prox becomes worse. We believe the reason is that, since in this experiment we only use a small number of eigenvectors to approximate a Hessian matrix of very high dimension, the approximation error is so large that SOLA-prox cannot find a descent direction that is close to the true gradient. This highlights the importance of future study of SOLA on models with more complicated structures or higher dimensions.

Performance vs approximation. We study how the approximation of the Hessian matrices affects the performance of SOLA-prox. In particular, we choose different values of the rank in SOLA-prox and investigate its correlation with the final average test accuracy. Our theory implies that a better approximation of the Hessian matrices, i.e., a smaller , leads to better final performance. Our experiments validate this point: Figures 0(a) and 0(b) show that as we increase the rank , i.e., use more eigenvectors to approximate the Hessian matrix, the average test accuracy over all tasks improves.
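As an illustration of this rank-vs-accuracy trade-off, the following minimal sketch (with a synthetic symmetric matrix standing in for a Hessian; the function name and rank values are ours, not from the paper) shows how a top-eigenvector approximation improves as the rank grows:

```python
import numpy as np

def low_rank_hessian(H, r):
    """Approximate a symmetric matrix H using its r largest-magnitude eigenpairs."""
    # eigh is for symmetric matrices; eigenvalues come back in ascending order
    vals, vecs = np.linalg.eigh(H)
    top = np.argsort(np.abs(vals))[::-1][:r]  # indices of the r largest |eigenvalues|
    # Reconstruct sum_i lambda_i v_i v_i^T over the selected eigenpairs
    return (vecs[:, top] * vals[top]) @ vecs[:, top].T

# Synthetic symmetric "Hessian" for illustration
rng = np.random.default_rng(0)
A = rng.standard_normal((50, 50))
H = (A + A.T) / 2

# Spectral-norm approximation error shrinks as the rank r grows,
# mirroring the accuracy-vs-rank trend reported for SOLA-prox.
errs = [np.linalg.norm(H - low_rank_hessian(H, r), 2) for r in (5, 20, 50)]
```

Here the error at rank r equals the (r+1)-th largest eigenvalue magnitude, so it decreases monotonically in r and vanishes at full rank.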

Figure 1: (a)(b): Average test accuracy of SOLA-prox vs the rank used for the approximation of the Hessian matrices. We choose the number of epochs per task in , and observe that the average test accuracy improves as we use more eigenvectors. (c)(d): Average test accuracy vs the number of epochs per task. For SOLA-prox, on Permuted MNIST we choose , and on Rotated MNIST we choose . On Permuted MNIST, the test accuracy of SOLA-prox becomes worse if we train for more than epochs per task, and on Rotated MNIST, the test accuracy of SOLA-prox gradually decreases as we increase the number of epochs per task.

Early stopping. Our theoretical analysis in Section 5 implies that early stopping can be helpful for SOLA. Here, we present empirical evidence. As we can see from Figure 0(c), on Permuted MNIST with MLP, the average test accuracy of SOLA-prox becomes worse if we train for more than epochs per task; similarly, from Figure 0(a), we can also see that training each task for more epochs can hurt the performance of MLP on Permuted MNIST. However, this phenomenon is less severe on Rotated MNIST. In Figure 0(d), for SOLA-prox with , we observe one case where the average test accuracy gradually decreases as we increase the number of epochs per task. Moreover, we did not observe this phenomenon for SOLA-exact. We therefore conclude that the importance of early stopping for SOLA depends on how different the tasks are and on how well we approximate the Hessian matrix. In Permuted MNIST, the pixels are randomly shuffled when switching to new tasks, whereas in Rotated MNIST we only rotate the images by degrees; thus early stopping is more important for Permuted MNIST. On the other hand, if we store the exact value of the Hessian matrix ( in Theorem 1), the approximation error of the gradients can be small, and thus we can train for more epochs on new tasks. In addition, we note that early stopping has been observed to be helpful for other continual learning algorithms as well farajtabar2019orthogonal .
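The early-stopping heuristic discussed here can be sketched generically as follows. This is a minimal illustration based on a held-out metric, not the training loop actually used in the paper; `step`, `evaluate`, and `patience` are our own hypothetical names:

```python
def train_with_early_stopping(step, evaluate, max_epochs, patience=2):
    """Train for at most max_epochs, stopping once the held-out
    metric fails to improve for `patience` consecutive epochs."""
    best_acc, best_epoch, stale = float("-inf"), 0, 0
    for epoch in range(1, max_epochs + 1):
        step()            # one epoch of training on the current task
        acc = evaluate()  # e.g., average test accuracy over all tasks seen so far
        if acc > best_acc:
            best_acc, best_epoch, stale = acc, epoch, 0
        else:
            stale += 1
            if stale >= patience:
                break     # accuracy on old tasks is degrading; stop early
    return best_epoch, best_acc
```

With a metric sequence that peaks early, as on Permuted MNIST above, such a loop stops shortly after the peak instead of continuing to train and forget.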

8 Conclusions

We propose the SOLA algorithm based on the idea of loss function approximation. We establish theoretical guarantees, make connections to the EWC algorithm, and present experimental results showing that in many scenarios, our algorithm outperforms several baseline algorithms, especially among the ones that do not explicitly store the raw data samples. Future directions include studying SOLA on broader classes of neural network architectures and parameter spaces with higher dimensions.

Acknowledgments

We would like to thank Dilan Gorur, Alex Mott, Clara Huiyi Hu, Nevena Lazic, Nir Levine, and Michalis Titsias for helpful discussions.



Appendix A Proof of Theorem 1

We first provide a bound for the difference between the gradients of and .

Lemma 1.

Let . Then we have

We prove Lemma 1 in Appendix A.1. Since the loss functions for all the tasks are -smooth, we know that is also -smooth. Then we have

Therefore, as long as for some , we have


Then we can complete the proof by combining (7) with Lemma 1.
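For reference, the smoothness property invoked in this argument is the standard gradient-Lipschitz condition; since the constant's symbol is elided in this copy, we write it generically as $\beta$:

```latex
% A function f is beta-smooth if its gradient is beta-Lipschitz:
\|\nabla f(w) - \nabla f(w')\| \le \beta \|w - w'\|
  \qquad \text{for all } w, w'
```

A convex combination (in particular, the average) of $\beta$-smooth functions is again $\beta$-smooth by the triangle inequality, which is what justifies the smoothness claim made above.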

A.1 Proof of Lemma 1

By the definition of , for some , , we have

where the second equality is due to Lagrange’s mean value theorem. Then, according to Assumptions 1 and 2, we have


Then, by the triangle inequality, we obtain

Appendix B Proof of Proposition 1

We first note that it suffices to construct and , as one can always choose and then the construction of and is equivalent to that of and . Let ,

One can easily check that , , and . In addition, since the second derivative of is always bounded in , we know that is smooth. Since , we know that is -Hessian Lipschitz. Therefore, and satisfy all of our assumptions.

Since , we know that , , and then

is equivalent to , which implies that .

Appendix C Proof of Theorem 2

As in Appendix A, we define . According to Assumptions 1 and 2, we know that both and are -smooth. By the smoothness of and using the fact that , we get

which implies


By averaging (9) over , we get

By taking the square root on both sides, and using the Cauchy-Schwarz inequality as well as the fact that , we get


We then proceed to bound . According to Lemma 1, we have