Uncertainty-guided Continual Learning with Bayesian Neural Networks

06/06/2019 ∙ by Sayna Ebrahimi, et al. ∙ UC Berkeley

Continual learning aims to learn new tasks without forgetting previously learned ones. This is especially challenging when one cannot access data from previous tasks and when the model has a fixed capacity. Current regularization-based continual learning algorithms need an external representation and extra computation to measure the parameters' importance. In contrast, we propose Uncertainty-guided Continual Bayesian Neural Networks (UCB), where the learning rate adapts according to the uncertainty defined in the probability distribution of the weights in networks. Uncertainty is a natural way to identify what to remember and what to change as we continually learn, allowing us to mitigate catastrophic forgetting. We also show a variant of our model that uses uncertainty for weight pruning and retains task performance after pruning by saving a binary mask per task. We evaluate our UCB approach extensively on diverse object classification datasets with short and long sequences of tasks and report superior or on-par performance compared to existing approaches. Additionally, we show that our model does not necessarily need task information at test time, i.e. it does not presume knowledge of which task a sample belongs to.


1 Introduction

Humans can easily accumulate and maintain knowledge gained from previously observed tasks, and continuously learn to solve new problems or tasks. Artificial learning systems typically forget prior tasks when they cannot access all training data at once but are presented with task data in sequence.

Overcoming these challenges is the focus of continual learning, sometimes also referred to as lifelong learning or sequential learning. Catastrophic forgetting mccloskey1989catastrophic ; mcclelland1995there refers to the significant drop in the performance of a learner when switching from a trained task to a new one. This phenomenon occurs because parameters trained on the initial task change in favor of learning new objectives, and it is why naive fine-tuning suffers from catastrophic forgetting.

Given a network of limited capacity, one way to address this problem is to identify the importance of each parameter and penalize further changes to those parameters that were deemed important for the previous tasks ewc ; mas ; SI . An alternative is to freeze the most important parameters and allow future tasks to adapt only the remaining parameters packnet . Such models rely on an explicit parametrization of importance. Here, we propose an implicit, uncertainty-guided representation of importance.

Bayesian neural networks blundell2015weight provide an intrinsic importance model based on weight uncertainty. These networks represent each parameter with a distribution defined by a mean and variance over possible values drawn from a shared latent probability distribution. Variational inference can approximate the posterior distributions using Monte Carlo sampling for gradient estimation. These networks act like ensemble methods in that they reduce prediction variance, yet use only twice the number of parameters of a regular neural network. We propose to use the predicted mean and variance of the latent distributions to characterize the importance of each parameter, and we perform continual learning with Bayesian neural networks by controlling the learning rate of each parameter as a function of its uncertainty. Figure 1 illustrates how posterior distributions evolve for certain and uncertain weight distributions while learning two consecutive tasks. Intuitively, the more uncertain a parameter is, the more learnable it is, and therefore larger gradient steps can be taken for it to learn the current task. As a hard version of this regularization technique, we also show that pruning, i.e., preventing the most important model parameters from any change and learning new tasks with the remaining parameters, can also be integrated into UCB. We refer to this method as UCB-P.

Figure 1: Illustration of the evolution of weight distributions (uncertain weights adapt more quickly) when learning two tasks with UCB. (a) Weight parameters are initialized with randomly sampled mean and variance values. (b) Posterior distributions after learning the first task: some weights exhibit low uncertainty and are deemed important, while others retain larger uncertainties, keeping them available to learn more tasks. (c) A second task is learned using higher learning rates for the previously uncertain parameters, while the learning rates of the certain ones are reduced. The size of the arrows indicates the magnitude of the change of the distribution mean upon gradient update.

Contributions: First, we propose to perform continual learning with Bayesian neural networks and develop a new method which exploits their inherent measure of uncertainty to adapt the learning rate of individual parameters (Sec. 4). Second, we introduce a hard-threshold variant of our method that decides which parameters to freeze (Sec. 4.2). Third, in Sec. 5, we extensively validate our approach experimentally, comparing it to prior art both on single datasets split into different tasks and in the more difficult scenario of learning a sequence of different datasets. Fourth, in contrast to most prior work, our approach does not rely on knowledge of task boundaries at inference time, which humans do not need and which might not always be available. We show in Sec. 6 that our approach naturally supports this scenario and does not require task information at test time, sometimes also referred to as a "single head" scenario for all tasks. We refer to the evaluation metric of a "single head" model without task information at test time as "generalized accuracy". The PyTorch code can be found in the supplemental material and will be published upon acceptance.

2 Related Work

Conceptually, approaches to continual learning can be divided into the following categories: dynamic architectural methods, memory-based methods, and regularization methods.

Dynamic architectural methods: In this setting, the architecture grows while keeping past knowledge fixed and storing new knowledge in different forms such as additional layers, nodes, or modules. The objective function remains fixed, whereas the model capacity grows, often exponentially, with the number of tasks. Progressive networks pnn ; schwarz2018progress were among the earliest works in this direction and were successfully applied to reinforcement learning problems; the base architecture is duplicated and lateral connections added in response to new tasks. Dynamically Expandable Network (DEN) yoon2018lifelong also expands its network, by selecting drifting units and retraining them on new tasks. In contrast to our method, these approaches require the architecture to grow with each new task.

Memory-based methods: In this regime, previous information is partially stored to be used later as a form of rehearsal robins1995catastrophic . Gradient episodic memory (GEM) gem uses this idea to store data at the end of each episode, which is later used to prevent gradient updates from deviating from their previous values. GEM also allows for positive backward knowledge transfer, i.e., an improvement on previously learned tasks, and it was the first method capable of learning with single training examples. In the realm of Bayesian deep learning, VCL vcl uses Bayesian inference to perform continual learning, where the new posterior distribution is obtained simply by multiplying the previous posterior by the likelihood of the dataset belonging to the new task. The authors also showed that by using a core-set, a small representative set of data from previous tasks, VCL experiences less forgetting. In contrast, we rely on Bayesian neural networks and use their predictive uncertainty to perform continual learning; moreover, we do not use episodic memory to store examples. A fast natural-gradient descent method for variational inference was introduced in NGD , in which the Fisher information matrix is approximated using the generalized Gauss-Newton method. In contrast, our work uses vanilla gradient descent. Although second-order optimization algorithms can be more accurate than first-order methods, they add considerable computational cost. ngd-vcl ; fast-ngd-vcl both investigate natural-gradient descent methods as an alternative to the vanilla gradient descent used in VCL and EWC.

Regularization methods: In these approaches, significant changes to the representation learned for previous tasks are prevented. This can be done by regularizing the objective function or enforced directly on the weight parameters. Typically, an importance measure is engineered to represent the importance of each parameter. Inspired by Bayesian learning, in the elastic weight consolidation (EWC) method ewc , important parameters are those with the highest values in terms of the Fisher information matrix. In Synaptic Intelligence (SI) SI , this notion of parameter importance is engineered to correlate with the loss function: parameters that contribute more to the loss are more important. Similar to SI, Memory-aware Synapses (MAS) mas proposed an online way of computing importance, adaptive to the test set, using the change in the model outputs w.r.t. the inputs. While all the above algorithms are task-dependent, in parallel development to this work, mas-tf has recently investigated task-free continual learning by building upon MAS and using a protocol to update the weights instead of waiting until tasks are finished. PackNet packnet used iterative pruning to fully restrict gradient updates on important weights via binary masks. This method requires knowing which task is being tested in order to use the appropriate mask. PackNet also ranks weight importance by magnitude, which is not guaranteed to be a reliable indicator of importance. HAT hat identifies important neurons by learning an attention vector over the task embedding to control gradient propagation. It maintains the information learned on previous tasks using an almost-binary mask per previous task.

3 Background: Variational Bayes-by-Backprop

In this section, we review the Bayesian neural network learning framework first introduced by blundell2015weight to learn a probability distribution over network parameters. blundell2015weight showed a back-propagation-compatible algorithm that acts as a regularizer and yields performance comparable to dropout on the MNIST dataset. In the next section (Section 4) we show how to use the derived parameter uncertainty to estimate the importance of a parameter for the current task, which we use to determine how much it should change when learning later tasks.

3.1 Variational inference in neural networks

In Bayesian models, latent variables w are drawn from a prior density P(w) and are related to the observations through the likelihood P(x|w). During inference, the posterior distribution P(w|x) is computed conditioned on the given input data. In practice, however, this probability distribution is intractable and is often estimated through approximate inference. Markov chain Monte Carlo (MCMC) sampling hastings1970monte has been widely used and explored for this purpose; see robert2013monte for different methods in this category. However, despite providing guarantees of asymptotically exact samples from the target distribution, MCMC algorithms are not suitable for large datasets and/or large models, as they are limited by speed and scalability. Alternatively, variational inference provides a faster solution to the same problem, in which the posterior is approximated through optimization rather than sampled from a chain peterson1987mean ; hinton1993keeping ; jaakkola1996computing ; jaakkola1997variational . Variational inference methods can take advantage of fast optimization techniques such as stochastic and distributed methods, which allow them to explore data models quickly. See blei2017variational for a complete review of the theory.

3.2 Bayes by Backprop (BBB)

Let x be a set of observed variables and w the latent variables (network weights). A neural network, as a probabilistic model P(y|x, w), given a set of training examples D = (x, y), outputs y, belonging to a set of classes, using the set of weight parameters w. Variational inference aims to calculate this conditional probability distribution over the latent variables by finding the closest proxy to the exact posterior through an optimization problem.

We first assume a family of probability densities over the latent variables w, parametrized by θ, i.e., q(w|θ). We then find the closest member of this family to the true conditional probability of interest by minimizing the Kullback-Leibler (KL) divergence between q(w|θ) and the true posterior P(w|D):

\theta^* = \arg\min_{\theta} \mathrm{KL}\left[ q(\mathbf{w}|\theta) \,\|\, P(\mathbf{w}|\mathcal{D}) \right] = \arg\min_{\theta} \int q(\mathbf{w}|\theta) \log \frac{q(\mathbf{w}|\theta)}{P(\mathbf{w})\,P(\mathcal{D}|\mathbf{w})} \, d\mathbf{w}   (1)

Once solved, q(w|θ*) is the closest approximation to the true posterior. Eq. 1 is commonly known as the variational free energy or expected lower bound:

\mathcal{F}(\mathcal{D}, \theta) = \mathrm{KL}\left[ q(\mathbf{w}|\theta) \,\|\, P(\mathbf{w}) \right] - \mathbb{E}_{q(\mathbf{w}|\theta)}\left[ \log P(\mathcal{D}|\mathbf{w}) \right]   (2)

Eq. 2 can be approximated using N Monte Carlo samples w^(i) drawn from the variational posterior blundell2015weight :

\mathcal{F}(\mathcal{D}, \theta) \approx \sum_{i=1}^{N} \log q(\mathbf{w}^{(i)}|\theta) - \log P(\mathbf{w}^{(i)}) - \log P(\mathcal{D}|\mathbf{w}^{(i)})   (3)

We assume q(w|θ) to be a Gaussian pdf with diagonal covariance, parametrized by θ = (μ, ρ). A sample weight of the variational posterior can be obtained by sampling ε ~ N(0, I) from a unit Gaussian and reparametrizing as w = μ + σ ∘ ε, where ∘ is a pointwise multiplication. The standard deviation is parametrized as σ = log(1 + exp(ρ)) and is thus always positive. For the prior, as suggested by blundell2015weight , a scale mixture of two Gaussian pdfs is chosen, both zero-centered but with different variances σ₁² and σ₂². The uncertainty obtained for every parameter has been successfully used in model compression han2015deep and uncertainty-based exploration in reinforcement learning blundell2015weight . In this work we propose to use this framework to learn sequential tasks without forgetting, using per-weight uncertainties.
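For concreteness, the following is a minimal PyTorch sketch of one Bayes-by-Backprop layer with the reparametrization w = μ + σ ∘ ε and the scale-mixture log-prior described above. The class name, initialization constants, and mixture hyperparameters (pi, sigma1, sigma2) are illustrative choices, not those of the released implementation.

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class BayesianLinear(nn.Module):
    """One Bayes-by-Backprop layer: each weight has variational parameters (mu, rho)."""
    def __init__(self, in_features, out_features):
        super().__init__()
        self.mu = nn.Parameter(torch.empty(out_features, in_features).normal_(0, 0.1))
        self.rho = nn.Parameter(torch.full((out_features, in_features), -3.0))

    @property
    def sigma(self):
        return F.softplus(self.rho)  # sigma = log(1 + exp(rho)), always positive

    def forward(self, x):
        # Reparametrization: w = mu + sigma * eps, eps ~ N(0, I)
        eps = torch.randn_like(self.mu)
        w = self.mu + self.sigma * eps
        return F.linear(x, w), w    # return the sample so loss terms can use it

def log_posterior(w, mu, sigma):
    # log q(w | theta) for a diagonal Gaussian, summed over all weights
    return (-0.5 * math.log(2 * math.pi) - torch.log(sigma)
            - (w - mu) ** 2 / (2 * sigma ** 2)).sum()

def log_mixture_prior(w, pi=0.5, sigma1=1.0, sigma2=0.0025):
    # log P(w) for a scale mixture of two zero-mean Gaussians (values illustrative)
    def log_gauss(w, sigma):
        return -0.5 * math.log(2 * math.pi) - math.log(sigma) - w.pow(2) / (2 * sigma ** 2)
    return torch.logaddexp(math.log(pi) + log_gauss(w, sigma1),
                           math.log(1.0 - pi) + log_gauss(w, sigma2)).sum()

Together with the data log-likelihood from the network output, these two terms give the Monte Carlo loss of Eq. 3.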

4 Uncertainty-guided Continual Learning in Bayesian Neural Networks

In this section, we introduce our Uncertainty-guided Continual learning approach with Bayesian neural networks (UCB), which exploits the estimated uncertainty of the parameters' posterior distribution to regulate the change in "important" parameters, either in a soft way (Section 4.1) or by setting a hard threshold (Section 4.2).

4.1 UCB with learning rate regularization

A common strategy to perform continual learning is to reduce forgetting by regularizing further changes in the model representation based on parameter importance. In UCB, this regularization is performed via the learning rate, such that the learning rate of each parameter, and hence its gradient update, becomes a function of its importance. In particular, as shown in the following equations, we scale the learning rates of μ and ρ for each parameter distribution inversely proportionally to the parameter's importance Ω, to reduce changes in important parameters while allowing less important parameters to change more in favor of learning new tasks.

\alpha_{\mu} \leftarrow \alpha_{\mu} / \Omega_{\mu}   (4)
\alpha_{\rho} \leftarrow \alpha_{\rho} / \Omega_{\rho}   (5)

The core idea of this work is to base the definition of importance on the well-defined uncertainty of the parameter distributions in Bayesian neural networks, i.e., to set the importance inversely proportional to the standard deviation σ, which represents the parameter uncertainty in the Bayesian neural network:

\Omega_{\mu} = 1 / \sigma   (6)

We explore different options for setting Ω in our ablation study presented in Section B of the appendix, Table 5. We empirically found that Ω_μ = 1/σ, together with not adapting the learning rate of ρ (i.e., Ω_ρ = 1), yields the highest accuracy and the least forgetting.

The key benefit of UCB with the learning rate as the regularizer is that it requires neither additional memory, as opposed to pruning techniques, nor tracking the change in parameters with respect to the previously learned tasks, as needed in common weight regularization methods.

1:  Require Training data for all tasks D = (x, y); μ (mean of posterior); ρ, with σ = log(1 + exp(ρ)) (std of posterior); σ₁ and σ₂ (std for the scaled mixture Gaussian pdf of prior); π (weighting factor for prior); N (number of samples); M (number of minibatches)
2:  Require Hyperparameters for training: initial learning rate α
3:  α_μ ← α, α_ρ ← α
4:  σ = log(1 + exp(ρ))                                          ▷ Ensures σ is always positive
5:  w = μ + σ ∘ ε, with ε ~ N(0, I)                              ▷ A posterior sample of weights
6:  for every task do
7:      repeat
8:          l₁ ← Σᵢ log q(w^(i)|θ)                               ▷ Log-posterior
9:          l₂ ← Σᵢ log(π N(w^(i)|0, σ₁²) + (1 − π) N(w^(i)|0, σ₂²))   ▷ Log-prior
10:         l₃ ← log P(D|w)                                      ▷ Log-likelihood of data
11:         L(D, θ) ← l₁ − l₂ − l₃                               ▷ Eq. 3
12:         μ ← μ − α_μ ∇_μ L
13:         ρ ← ρ − α_ρ ∇_ρ L
14:     until loss plateaus
15:     α_μ, α_ρ ← LearningRateUpdate(α_μ, α_ρ, μ, σ)            ▷ See Algorithm 2 for UCB and 3 for UCB-P
16: end for
Algorithm 1 Uncertainty-guided Continual Learning with Bayesian Neural Networks (UCB)
1:  function LearningRateUpdate(α_μ, α_ρ, μ, σ)
2:      for each parameter do
3:          Ω_μ ← 1/σ                                            ▷ Importance is inversely proportional to uncertainty
4:          Ω_ρ ← 1
5:          α_μ ← α_μ / Ω_μ
6:          α_ρ ← α_ρ / Ω_ρ
7:      end for
8:  end function
Algorithm 2 LearningRateUpdate in UCB
1:  function LearningRateUpdate(α_μ, α_ρ, μ, σ)
2:      for each parameter j in each layer l do
3:          Ω ← |μ|/σ                                            ▷ Signal-to-noise ratio
4:          if Ω is among the top p% of Ωs in layer l then
5:              α_μ, α_ρ ← 0
6:          end if
7:      end for
8:  end function
Algorithm 3 LearningRateUpdate in UCB-P

More importantly, this method does not need to be aware of task switching, as it only adjusts the learning rates of the means in the posterior distribution based on their current uncertainty. The complete algorithm for UCB is shown in Algorithm 1, with the parameter update function given in Algorithm 2.
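As a concrete illustration, below is a minimal PyTorch sketch of the update in Algorithm 2. Since torch.optim applies one learning rate per tensor, the sketch instead scales the gradient of each mean elementwise by σ (i.e., by 1/Ω_μ) before a plain SGD step, which is equivalent; the module attribute names (mu, rho) follow the layer sketch in Section 3.2 and are our own convention.

import torch
import torch.nn.functional as F

@torch.no_grad()
def ucb_step(model, base_lr=0.01):
    """One SGD update with uncertainty-scaled learning rates:
    Omega_mu = 1/sigma and Omega_rho = 1 (the setting the ablation found best),
    so lr_mu = base_lr * sigma elementwise, while lr_rho = base_lr."""
    for module in model.modules():
        if hasattr(module, "mu") and hasattr(module, "rho"):
            sigma = F.softplus(module.rho)              # current per-weight uncertainty
            module.mu -= base_lr * sigma * module.mu.grad   # certain weights move less
            module.rho -= base_lr * module.rho.grad         # uncertainty unregularized

This would be called after loss.backward() in place of a standard optimizer.step().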

4.2 UCB using weight pruning (UCB-P)

In this section, we introduce a variant of our method, UCB-P, which is related to recent efforts in weight pruning in the context of reducing inference computation and network compression liu2017learning ; molchanov2016pruning . More specifically, weight pruning has recently been used in continual learning packnet , where the goal is to continue learning multiple tasks using a single network's capacity. packnet accomplished this by freeing up parameters deemed unimportant to the current task according to their magnitude. Forgetting is prevented in pruning by saving a task-specific binary mask of important vs. unimportant parameters. Here, we adapt pruning to Bayesian neural networks. Specifically, we propose a different criterion for measuring importance: the statistically grounded uncertainty defined in Bayesian neural networks.

Unlike regular deep neural networks, in a BBB model weight parameters are represented by probability distributions parametrized by their mean and standard deviation. Similar to blundell2015weight , in order to take both mean and standard deviation into account, we use the signal-to-noise ratio (SNR) of each parameter, defined as

\Omega = \mathrm{SNR} = |\mu| / \sigma   (7)

SNR is a commonly used measure in signal processing for distinguishing "useful" information from unwanted noise contained in a signal. In the context of neural models, the SNR can be thought of as indicative of parameter importance: the higher the SNR, the more effective or important the parameter is to the model's predictions for a given task.

UCB-P, as shown in Algorithms 1 and 3, is performed as follows: for every layer, convolutional or fully connected, the parameters are ordered by their SNR value, and those with the lowest importance are pruned (set to zero). The pruned parameters are marked using a binary mask so that they can be used later for learning new tasks, whereas the important parameters remain fixed throughout training on future tasks. Once a task is learned, the associated binary mask is saved and is used during inference to recover the key parameters, and hence the exact performance, for the desired task.

The overhead memory per parameter for encoding the mask, as well as saving it on disk, is as follows: assuming we have T tasks to learn with a single network, the total number of bits required to encode the accumulated mask for a parameter is at most T bits, in the case of a parameter deemed important from task 1 onward that is kept encoded in the mask.
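A minimal sketch of the per-layer SNR ranking used by UCB-P is given below; the keep_frac argument and the mask bookkeeping are illustrative assumptions of ours, while the SNR criterion itself follows Eq. 7.

import torch
import torch.nn.functional as F

def snr_masks(model, keep_frac=0.5):
    """Per layer, mark the top keep_frac of weights by SNR = |mu| / sigma as
    important (True); these stay frozen, while the rest are pruned and reused."""
    masks = {}
    for name, module in model.named_modules():
        if hasattr(module, "mu") and hasattr(module, "rho"):
            snr = module.mu.abs() / F.softplus(module.rho)   # Eq. (7)
            k = max(1, int(keep_frac * snr.numel()))
            threshold = snr.flatten().topk(k).values.min()   # k-th largest SNR
            masks[name] = snr >= threshold                   # True = keep fixed
    return masks

One such boolean mask would be stored per task and applied at inference to recover that task's important parameters.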

5 Results

5.1 Experimental Setup

Datasets: We evaluate our approach in two common scenarios for continual learning: 1) class-incremental learning of a single or of two randomly alternating datasets, where each task covers only a subset of the classes in a dataset, and 2) continual learning of multiple datasets, where each task is a dataset. We use Split MNIST and Permuted MNIST pmnist for class-incremental learning, with experimental settings similar to those used in hat ; gem . Furthermore, to better understand our method, we evaluate our approach on continually learning a sequence of datasets with different distributions, using the identical sequence as in hat , which includes FaceScrub facescrub , MNIST, CIFAR100, NotMNIST notmnist , SVHN svhn , CIFAR10, TrafficSigns traffic , and FashionMNIST fmnist . Details of each are summarized in Table 3 in the Appendix. No data augmentation of any kind has been used in our analysis.

Baselines: Within the Bayesian framework, we compare to three models which do not incorporate the importance of parameters, namely fine-tuning, feature extraction, and joint training. In fine-tuning (BBB-FT), training continues upon arrival of new tasks without any forgetting-avoidance strategy. Feature extraction (BBB-FE) refers to freezing all layers in the network after training the first task and training only the last layer for the remaining tasks. In joint training (BBB-JT) we learn all tasks jointly, in a multitask learning fashion, which serves as the upper bound for average accuracy on all tasks, as it does not adhere to the continual learning scenario. We also evaluate the counterparts of FT, FE, and JT with ordinary neural networks, denoted as ORD-FT, ORD-FE, and ORD-JT. From prior work, we compare with state-of-the-art approaches including Elastic Weight Consolidation (EWC) ewc , Incremental Moment Matching (IMM) imm , Learning Without Forgetting (LWF) lwf , Less-Forgetting Learning (LFL) lfl , PathNet pathnet , Progressive neural networks (PNNs) pnn , and Hard Attention Mask (HAT) hat , using the implementations provided by hat . On Permuted MNIST, results for GEM gem and SI SI are reported from hat . Results for VCL vcl are adapted directly from the original work without re-implementation.

Implementation details: We provide full details of our experimental setting, including network architectures, hyperparameter tuning method, training parameters, pruning procedure, and the choice of the number of Monte Carlo samples, in Section B of the appendix.

Performance measurement: Let T be the total number of tasks. Once all are learned, we evaluate our model on all T tasks. ACC is the average test classification accuracy across all tasks. To measure forgetting we report backward transfer, BWT, which indicates how much learning new tasks has influenced the performance on previous tasks. While BWT < 0 directly reports catastrophic forgetting, BWT > 0 indicates that learning new tasks has helped the preceding tasks. Formally, BWT and ACC are defined as:

\mathrm{BWT} = \frac{1}{T-1} \sum_{i=1}^{T-1} \left( R_{T,i} - R_{i,i} \right), \qquad \mathrm{ACC} = \frac{1}{T} \sum_{i=1}^{T} R_{T,i}   (8)

where R_{n,i} is the test classification accuracy on task i after sequentially finishing learning the n-th task. Note that in UCB-P, R_{i,i} refers to the test accuracy on task i before pruning and R_{T,i} after pruning, which is equivalent to the end-of-sequence performance. In Section 6, we show that our UCB model can be used when task labels are not available at inference time, by training it with a "single head" architecture with a sum of the number of classes over all tasks. We refer to the ACC measured for this scenario as "Generalized Accuracy".
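Both metrics can be computed directly from the matrix of accuracies R; a minimal sketch (assuming R[n][i] is recorded by the evaluation loop exactly as defined above) is:

import numpy as np

def acc_bwt(R):
    R = np.asarray(R)     # shape (T, T); row n = accuracies measured after task n
    T = R.shape[0]
    acc = R[T - 1].mean()                                          # average final accuracy
    bwt = np.mean([R[T - 1, i] - R[i, i] for i in range(T - 1)])   # backward transfer
    return acc, bwt

# Example: two tasks with slight forgetting on task 0 after learning task 1.
print(acc_bwt([[0.99, 0.0], [0.97, 0.98]]))  # ACC = 0.975, BWT ≈ -0.02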

Method BWT ACC
PackNet packnet
LWF lwf
HAT hat
ORD-FT
ORD-FE
BBB-FT
BBB-FE
UCB-P (Ours)
UCB (Ours)
ORD-JT
BBB-JT
(a) Split MNIST, two tasks.
Method #Params BWT ACC
GEM gem -
SI SI M -
EWC ewc M -
VCL vcl -
HAT hat M -
UCB (Ours) M
LWF lwf M
IMM imm M
HAT hat M
BBB-FT M
BBB-FE M
UCB-P (Ours) M
UCB (Ours) M
BBB-JT M
(b) Permuted MNIST, permutations.
Method BWT ACC
PathNet pathnet
LWF lwf
LFL lfl
IMM imm
PNN pnn
EWC ewc
HAT hat
BBB-FE
BBB-FT
UCB-P (Ours)
UCB (Ours)
BBB-JT
(c) Alternating CIFAR10/100
Method BWT ACC
LFL lfl
PathNet pathnet
LWF lwf
IMM imm
EWC ewc
PNN pnn
HAT hat
BBB-FT
BBB-FE
UCB-P (Ours)
UCB (Ours)
BBB-JT
(d) Sequence of eight tasks
Table 1: Continually learning on different datasets. BWT and ACC in %. (*) denotes methods that do not adhere to the continual learning setup: BBB-JT and ORD-JT serve as upper bounds on ACC for the BBB/ORD networks, respectively. Some results are reported by hat or taken from the original work; BWT was not reported for those. All other results are (re)produced by us and averaged over runs, with standard deviations given in Section B of the appendix.

5.2 Split MNIST

We first present our results for class-incremental learning on MNIST (Split MNIST), in which we learn the digits in two tasks, with randomly shuffled classes in each. Table 1(a) shows the results for the reference baselines in Bayesian and non-Bayesian neural networks, including fine-tuning (ORD-FT, BBB-FT), feature extraction (ORD-FE, BBB-FE), and joint training (ORD-JT, BBB-JT), averaged over runs; standard deviations are given in Table 7 in the Appendix. Although MNIST is an "easy" dataset, we observe throughout all experiments that Bayesian fine-tuning and joint training perform significantly better than their counterparts, ORD-FT and ORD-JT. We also compare against PackNet, HAT, and LWF: PackNet, HAT, UCB-P, and UCB all have zero forgetting, while UCB has marginally higher accuracy than all others. (Figure 2 in the appendix shows the final accuracy upon finishing both tasks for UCB and the baselines.)

5.3 Permuted MNIST

Permuted MNIST is a popular variant of the MNIST dataset for evaluating continual learning approaches, in which each task is a random permutation of the original MNIST pixels. Following the literature, we learn a sequence of random permutations and report the average accuracy at the end. Table 1(b) shows ACC and BWT of UCB and UCB-P in comparison to state-of-the-art models using a small and a large network (architecture details are given in Section B of the appendix). The accuracy achieved by UCB using the small network outperforms the ACC reported in hat for GEM, SI, EWC, and VCL, while HAT attains a comparable ACC. Comparing the results for the larger network, while HAT and UCB both have zero forgetting, UCB performs better than all baselines, including HAT. We also observe again that BBB-FT, despite not being specifically penalized to prevent forgetting, exhibits reasonable negative BWT values, performing better than the IMM and LWF baselines, and is close to joint training, BBB-JT, which can be seen as an upper bound. The individual performance of the tasks is shown in Figure 3 in the appendix.

5.4 Alternating CIFAR10 and CIFAR100

In this experiment, we randomly alternate between class-incremental learning of CIFAR10 and CIFAR100. Both datasets are divided into 5 tasks, with 2 and 20 classes per task, respectively. Table 1(c) presents the ACC and BWT obtained with UCB-P, UCB, and three reference methods compared against various continual learning baselines. Among the baselines presented in Table 1(c), PNN and PathNet are the only approaches with guaranteed zero forgetting. It is interesting to note that in this setup, some baselines (PathNet, LWF, and LFL) do not perform better than the naive accuracy achieved by feature extraction. PathNet suffers from a bad pre-assignment of the network's capacity per task, which causes poor performance on the initial task from which it never recovers. IMM performs almost on par with fine-tuning in ACC, yet forgets more. PNN, EWC, and HAT are the only baselines that perform better than BBB-FE and BBB-FT. EWC and HAT are both allowed to forget by construction, but HAT shows zero-forgetting behavior. While EWC is outperformed by both of our UCB variants, HAT exhibits better ACC than UCB-P. Despite slightly higher forgetting, the overall accuracy of UCB is higher.

5.5 Multiple datasets learning

Finally, we present our results for continually learning a sequence of eight tasks using UCB-P and UCB in Table 1(d). As in the previous experiments, we look at both ACC and BWT obtained for UCB-P, UCB, the BBB references (FT, FE, JT), and various baselines. Considering the ACC achieved by BBB-FT or BBB-FE as a lower bound, we observe again that some baselines, including LFL, PathNet, LWF, IMM, and EWC, are unable to do better than BBB-FT, while PNN and HAT remain the only strong baselines for our UCB-P and UCB approaches. UCB-P again outperforms PNN in ACC, and while HAT exhibits only marginal BWT, our UCB achieves a higher ACC.

6 Single Head and Generalized Accuracy of UCB

Generalized ACC ACC
Single Head Single Head Multi Head
Exp UCB BBB-FT UCB BBB-FT UCB BBB-FT
SM
PM
CF
8T
Table 2: Single Head vs. Multi-Head architecture and Generalized vs. Standard Accuracy. Generalized accuracy means that task information is not available at test time. SM, PM, CF, and 8T denote Split MNIST, Permuted MNIST, Alternating CIFAR10/100, and the sequence of eight tasks, respectively.

UCB can be used even when task information is not given at test time. For this purpose, at training time, instead of using a separate fully connected classification head for each task, we use a single head with the total number of outputs for all tasks. For example, in the eight-dataset experiment we use only one head with the total number of output classes across all datasets, rather than eight separate heads, during both training and inference. Table 2 presents our results for UCB and BBB-FT trained with a single head against a multi-head architecture (columns under "ACC"). Interestingly, we see only a small performance degradation for UCB when moving from multi-head to single-head training across the Split MNIST, Permuted MNIST, Alternating CIFAR10/100, and eight-task experiments.

We also evaluated UCB and BBB-FT with a more challenging metric, where the prediction space covers the classes across all tasks; hence, confusion of similar class labels across tasks can be measured. Performance under this condition is reported as Generalized ACC in Table 2. We observe only a small performance reduction in going from ACC to Generalized ACC, suggesting that the presence of a larger number of classes at test time causes no significant confusion. This shows that UCB can perform competitively under more realistic conditions, such as the unavailability of task information at test time. We believe the main insight of our approach is that instead of computing additional measurements of importance, which are often task-, input-, or output-dependent, we directly use the predicted weight uncertainty to find important parameters. We can freeze them using a binary mask, as in UCB-P, or regularize changes conditioned on current uncertainty, as in UCB.
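The difference between the standard and generalized settings amounts to whether the argmax over the single head is restricted to the test task's class indices; a minimal sketch (with hypothetical argument names) is:

import torch

def predict(logits, task_classes=None):
    """logits: (batch, total_classes) from the single head.
    task_classes: class indices of the test task, or None for the generalized
    setting where no task information is available."""
    if task_classes is None:
        return logits.argmax(dim=1)               # generalized: all classes compete
    idx = torch.as_tensor(task_classes)
    return idx[logits[:, idx].argmax(dim=1)]      # standard: task is known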

7 Conclusion

In this work, we propose a continual learning formulation with Bayesian neural networks, called UCB, that uses uncertainty predictions to perform continual learning: important parameters can either be fully preserved through a saved binary mask (UCB-P) or allowed to change conditioned on their uncertainty when learning new tasks (UCB). We demonstrated how the per-weight uncertainty of the probabilistic weight distributions helps to continually learn short and long sequences of benchmark datasets, compared against baselines and prior work. We show that UCB performs superior to or on par with state-of-the-art models such as HAT hat across all experiments. Choosing between the two UCB variants depends on the application scenario: UCB-P enforces no forgetting after the initial pruning stage by saving a small binary mask per task, whereas UCB requires no additional memory and allows for more learning flexibility by permitting small amounts of forgetting. UCB can also be used in a single-head setting, where the subset of classes belonging to the task is not known during inference, leading to a competitive model that can be deployed when it is not possible to distinguish tasks in a continuous stream of data at test time.

References

  • (1) R. Aljundi, F. Babiloni, M. Elhoseiny, M. Rohrbach, and T. Tuytelaars. Memory aware synapses: Learning what (not) to forget. In Proceedings of the European Conference on Computer Vision (ECCV), pages 139–154, 2018.
  • (2) R. Aljundi, K. Kelchtermans, and T. Tuytelaars. Task-free continual learning. arXiv preprint arXiv:1802.05800, 2018.
  • (3) D. M. Blei, A. Kucukelbir, and J. D. McAuliffe. Variational inference: A review for statisticians. Journal of the American Statistical Association, 112(518):859–877, 2017.
  • (4) C. Blundell, J. Cornebise, K. Kavukcuoglu, and D. Wierstra. Weight uncertainty in neural network. In F. Bach and D. Blei, editors, Proceedings of the 32nd International Conference on Machine Learning, volume 37 of Proceedings of Machine Learning Research, pages 1613–1622. PMLR, 2015.
  • (5) Y. Bulatov. NotMNIST dataset. Google (Books/OCR), Tech. Rep. [Online]. Available: http://yaroslavvb.blogspot.it/2011/09/notmnist-dataset.html, 2011.
  • (6) A. Chaudhry, M. Ranzato, M. Rohrbach, and M. Elhoseiny. Efficient lifelong learning with A-GEM. In International Conference on Learning Representations, 2019.
  • (7) Y. Chen, T. Diethe, and N. Lawrence. Facilitating bayesian continual learning by natural gradients and stein gradients. arXiv preprint arXiv:1904.10644, 2019.
  • (8) C. Fernando, D. Banarse, C. Blundell, Y. Zwols, D. Ha, A. A. Rusu, A. Pritzel, and D. Wierstra. Pathnet: Evolution channels gradient descent in super neural networks. arXiv preprint arXiv:1701.08734, 2017.
  • (9) S. Han, H. Mao, and W. J. Dally. Deep compression: Compressing deep neural networks with pruning, trained quantization and huffman coding. arXiv preprint arXiv:1510.00149, 2015.
  • (10) W. K. Hastings. Monte carlo sampling methods using markov chains and their applications. Biometrika, 1970.
  • (11) G. E. Hinton and D. Van Camp. Keeping the neural networks simple by minimizing the description length of the weights. In Proceedings of the Sixth Annual Conference on Computational Learning Theory, pages 5–13. ACM, 1993.
  • (12) T. Jaakkola and M. Jordan. A variational approach to Bayesian logistic regression models and their extensions. In Sixth International Workshop on Artificial Intelligence and Statistics, volume 82, page 4, 1997.
  • (13) T. S. Jaakkola and M. I. Jordan. Computing upper and lower bounds on likelihoods in intractable networks. In Proceedings of the Twelfth international conference on Uncertainty in artificial intelligence, pages 340–348. Morgan Kaufmann Publishers Inc., 1996.
  • (14) H. Jung, J. Ju, M. Jung, and J. Kim. Less-forgetting learning in deep neural networks. arXiv preprint arXiv:1607.00122, 2016.
  • (15) M. E. Khan and D. Nielsen. Fast yet simple natural-gradient descent for variational inference in complex models. In 2018 International Symposium on Information Theory and Its Applications (ISITA), pages 31–35. IEEE, 2018.
  • (16) J. Kirkpatrick, R. Pascanu, N. Rabinowitz, J. Veness, G. Desjardins, A. A. Rusu, K. Milan, J. Quan, T. Ramalho, A. Grabska-Barwinska, et al. Overcoming catastrophic forgetting in neural networks. Proceedings of the national academy of sciences, page 201611835, 2017.
  • (17) A. Krizhevsky and G. Hinton. Learning multiple layers of features from tiny images. Technical report, Citeseer, 2009.
  • (18) Y. LeCun, L. Bottou, Y. Bengio, and P. Haffner. Gradient-based learning applied to document recognition. Proceedings of the IEEE, 86(11):2278–2324, 1998.
  • (19) S.-W. Lee, J.-H. Kim, J. Jun, J.-W. Ha, and B.-T. Zhang. Overcoming catastrophic forgetting by incremental moment matching. In Advances in Neural Information Processing Systems, pages 4652–4662, 2017.
  • (20) Z. Li and D. Hoiem. Learning without forgetting. In European Conference on Computer Vision, pages 614–629. Springer, 2016.
  • (21) Z. Liu, J. Li, Z. Shen, G. Huang, S. Yan, and C. Zhang. Learning efficient convolutional networks through network slimming. In Computer Vision (ICCV), 2017 IEEE International Conference on, pages 2755–2763. IEEE, 2017.
  • (22) D. Lopez-Paz et al. Gradient episodic memory for continual learning. In Advances in Neural Information Processing Systems, pages 6467–6476, 2017.
  • (23) A. Mallya and S. Lazebnik. PackNet: Adding multiple tasks to a single network by iterative pruning. In IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2018.
  • (24) J. L. McClelland, B. L. McNaughton, and R. C. O’reilly. Why there are complementary learning systems in the hippocampus and neocortex: insights from the successes and failures of connectionist models of learning and memory. Psychological review, 102(3):419, 1995.
  • (25) M. McCloskey and N. J. Cohen. Catastrophic interference in connectionist networks: The sequential learning problem. In Psychology of learning and motivation, volume 24, pages 109–165. Elsevier, 1989.
  • (26) P. Molchanov, S. Tyree, T. Karras, T. Aila, and J. Kautz. Pruning convolutional neural networks for resource efficient inference. In International Conference on Learning Representations (ICLR), 2016.
  • (27) Y. Netzer, T. Wang, A. Coates, A. Bissacco, B. Wu, and A. Y. Ng. Reading digits in natural images with unsupervised feature learning. In NIPS workshop on deep learning and unsupervised feature learning, 2011.
  • (28) H.-W. Ng and S. Winkler. A data-driven approach to cleaning large face datasets. In Image Processing (ICIP), 2014 IEEE International Conference on, pages 343–347. IEEE, 2014.
  • (29) C. V. Nguyen, Y. Li, T. D. Bui, and R. E. Turner. Variational continual learning. In ICLR, 2018.
  • (30) C. Peterson. A mean field theory learning algorithm for neural networks. Complex systems, 1:995–1019, 1987.
  • (31) C. Robert and G. Casella. Monte Carlo statistical methods. Springer Science & Business Media, 2013.
  • (32) A. Robins. Catastrophic forgetting, rehearsal and pseudorehearsal. Connection Science, 7(2):123–146, 1995.
  • (33) A. A. Rusu, N. C. Rabinowitz, G. Desjardins, H. Soyer, J. Kirkpatrick, K. Kavukcuoglu, R. Pascanu, and R. Hadsell. Progressive neural networks. arXiv preprint arXiv:1606.04671, 2016.
  • (34) J. Schwarz, W. Czarnecki, J. Luketina, A. Grabska-Barwinska, Y. W. Teh, R. Pascanu, and R. Hadsell. Progress & compress: A scalable framework for continual learning. In J. Dy and A. Krause, editors, Proceedings of the 35th International Conference on Machine Learning, volume 80 of Proceedings of Machine Learning Research, pages 4528–4537. PMLR, 2018.
  • (35) J. Serra, D. Suris, M. Miron, and A. Karatzoglou. Overcoming catastrophic forgetting with hard attention to the task. In J. Dy and A. Krause, editors, Proceedings of the 35th International Conference on Machine Learning, volume 80 of Proceedings of Machine Learning Research, pages 4548–4557. PMLR, 2018.
  • (36) R. K. Srivastava, J. Masci, S. Kazerounian, F. Gomez, and J. Schmidhuber. Compete to compute. In Advances in neural information processing systems, pages 2310–2318, 2013.
  • (37) J. Stallkamp, M. Schlipsing, J. Salmen, and C. Igel. The german traffic sign recognition benchmark: a multi-class classification competition. In Neural Networks (IJCNN), The 2011 International Joint Conference on, pages 1453–1460. IEEE, 2011.
  • (38) H. Tseran, M. E. Khan, T. Harada, and T. D. Bui. Natural variational continual learning. 2018.
  • (39) H. Xiao, K. Rasul, and R. Vollgraf. Fashion-mnist: a novel image dataset for benchmarking machine learning algorithms. arXiv preprint arXiv:1708.07747, 2017.
  • (40) J. Yoon, E. Yang, J. Lee, and S. J. Hwang. Lifelong learning with dynamically expandable networks. In International Conference on Learning Representations, 2018.
  • (41) F. Zenke, B. Poole, and S. Ganguli. Continual learning through synaptic intelligence. In D. Precup and Y. W. Teh, editors, Proceedings of the 34th International Conference on Machine Learning, volume 70 of Proceedings of Machine Learning Research, pages 3987–3995. PMLR, 2017.

Appendix

A. Datasets

Table 3 shows a summary of the datasets utilized in our work, along with their sizes and numbers of classes. In all experiments we resized images to a common resolution where necessary. For datasets with monochromatic images, we replicate the image across all RGB channels.

Names Classes Train Test
FaceScrub facescrub 100 20,600 2,289
MNIST mnist 10 60,000 10,000
CIFAR100 cifar 100 50,000 10,000
NotMNIST notmnist 10 16,853 1,873
SVHN svhn 10 73,257 26,032
CIFAR10 cifar 10 50,000 10,000
TrafficSigns traffic 43 39,209 12,630
FashionMNIST fmnist 10 60,000 10,000
Table 3: Utilized datasets summary

B. Implementation Details

Network architecture: For the Split MNIST and Permuted MNIST experiments, we used a two-layer perceptron. Because a Bayesian neural network has more parameters than its equivalent regular neural network, we ensured a fair comparison by matching the total number of parameters between the two, unless otherwise stated. For the multiple-dataset learning scenario, as well as the alternating incremental CIFAR10/100 datasets, we used a ResNet18 Bayesian neural network, with the parameter count depending on the experiment. However, the majority of the baselines considered in this work were originally developed using variants of the AlexNet structure, and altering that, e.g., to ResNet18, degraded their reported and experimented performance, as shown in Table 4. Therefore, we kept the AlexNet architecture for the baselines and ResNet18 for ours, and only matched their numbers of parameters to ensure equal capacity across the different approaches.

Method BWT ACC
HAT (AlexNet)
HAT (ResNet18)
UCB (AlexNet)
UCB (ResNet18)
Table 4: Continually learning on CIFAR10/100 using AlexNet and ResNet18 for UCB (our method) and HAT hat . BWT and ACC in %. All results are (re)produced by us.

Hyperparameter tuning: Unlike commonly used tuning techniques, which use a validation set composed of all classes in the dataset, we rely only on the first two tasks and their validation sets, similar to the setup in agem . In all our experiments we hold out a validation split from the first two tasks. After tuning, training starts from the beginning of the sequence. Our scheme differs from agem , where the models are trained on the first few tasks for validation, training is then restarted for the remaining tasks, and the reported performance covers only the remaining tasks.

Training details: It is important to note that no pre-trained model is used in any of our experiments. We used stochastic gradient descent with a fixed batch size and initial learning rate, decaying the latter by a constant factor once the loss plateaued. Dataset splits and batch shuffling are identical in all UCB experiments and all baselines.

Pruning procedure: Once a task is learned, we compute, for a set of candidate pruning percentages, the performance drop from the maximum training accuracy achieved with no pruning. The pruning portion is then chosen using a threshold beyond which the performance drop is not accepted. Depending on the dataset, we used a small range of acceptable performance drop.
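A minimal sketch of this selection procedure, with a hypothetical evaluate callback and illustrative candidate fractions and tolerance, is:

def choose_prune_fraction(evaluate, candidates=(0.1, 0.3, 0.5, 0.7, 0.9), max_drop=0.002):
    """Keep the largest pruning fraction whose training-accuracy drop from the
    unpruned model stays within the tolerated budget max_drop."""
    base = evaluate(0.0)                       # accuracy with no pruning
    best = 0.0
    for frac in sorted(candidates):
        if base - evaluate(frac) <= max_drop:  # drop still acceptable
            best = frac                        # prefer the largest such fraction
    return best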

Parameter regularization and importance measurement: Table 5 ablates different ways of computing the importance of a parameter in Eqs. 4 and 5. As shown in Table 5, the configuration that yields the highest accuracy and the least forgetting (maximum BWT) occurs when learning rate regularization is performed only on the mean μ of the posteriors, using Ω_μ = 1/σ as the importance and Ω_ρ = 1.

Method Importance BWT (%) ACC (%)
UCB x -
UCB - x
UCB x x
UCB x -
UCB - x
UCB x x
UCB-P x x
UCB-P x x
Table 5: Variants of learning rate regularization and importance measurement on Split MNIST

Number of Monte Carlo samples: UCB is made robust to random noise by using multiple samples drawn from the posteriors. Here we explore different numbers of samples and their effect on the final ACC and BWT. We used 1/σ as the importance, and regularization was performed on the mean values only. Following the results in Table 6, we fixed the number of samples across all experiments.

Method BWT (%) ACC (%)
UCB
UCB
UCB
UCB
UCB
Table 6: Number of Monte Carlo samples (N) in Split MNIST

C. Additional results

Here we include complementary results for tables in the main text: Tables 7, 8, and 9 provide standard deviations for the results shown in Tables 1(a), 1(b), and 1(c), respectively.

Method BWT ACC
PackNet packnet
LWF lwf
HAT hat
ORD-FT
ORD-FE
BBB-FT
BBB-FE
UCB-P (Ours)
UCB (Ours)
ORD-JT
BBB-JT
Table 7: Continually learning on Split MNIST. BWT and ACC in %. (*) denotes that methods do not adhere to the continual learning setup: BBB-JT and ORD-JT serve as the upper bound for ACC for BBB/ORD networks, respectively. All results are (re)produced by us.
Method #Params BWT ACC
GEM gem -
SI SI -
EWC ewc -
VCL vcl -
HAT hat -
UCB (Ours)

LWF lwf
IMM imm
HAT hat
BBB-FT
BBB-FE
UCB-P (Ours)
UCB (Ours)
BBB-JT
Table 8: Continually learning on Permuted MNIST. BWT and ACC in %. (*) denotes that the method does not adhere to the continual learning setup: BBB-JT serves as the upper bound on ACC for the BBB network. Some results are reported by hat or taken from the original work; BWT was not reported for those. All other results are (re)produced by us.
Method BWT ACC
PathNet pathnet
LWF lwf
LFL lfl
IMM imm
PNN pnn
EWC ewc
HAT hat
BBB-FE
BBB-FT
UCB-P (Ours)
UCB (Ours)
BBB-JT
Table 9: Continually learning on CIFAR10/100. BWT and ACC in %. (*) denotes that the method does not adhere to the continual learning setup: BBB-JT serves as the upper bound on ACC for the BBB network. All results are (re)produced by us.

In Figures 2, 3, 4, and 5 we illustrate per-task accuracy at the end of training for the Split MNIST, Permuted MNIST, alternating CIFAR10/100, and sequence-of-tasks experiments, respectively.

Figure 2: Per-task accuracy at the end of training on Split MNIST.
Figure 3: Per-task accuracy at the end of training on Permuted MNIST.
Figure 4: Per-task accuracy at the end of training on the full sequence of alternating incremental CIFAR10/CIFAR100.
Figure 5: Per-task accuracy at the end of training on the full sequence of eight tasks.