Learning to Continually Learn Rapidly from Few and Noisy Data

03/06/2021 · Nicholas I-Hsien Kuo, et al. · Australian National University

Neural networks suffer from catastrophic forgetting and are unable to learn new tasks sequentially unless stationarity of the data distribution is guaranteed. Continual learning can be achieved via replay: concurrently training on externally stored old data while learning a new task. However, replay becomes less effective when each past task is allocated less memory. To overcome this difficulty, we supplemented the replay mechanics with meta-learning for rapid knowledge acquisition. By employing a meta-learner that learns a learning rate per parameter per past task, we found that base learners produced strong results when less memory was available. Additionally, our approach inherited several meta-learning advantages for continual learning: it demonstrated strong robustness when continually learning in the presence of noise, and it trained base learners to higher accuracy in fewer updates.


Introduction

In standard practice, it is assumed that the examples used to train a neural network are drawn independently and identically distributed (i.i.d.) from some fixed distribution. In a changing environment, however, continual learning is required: an agent faces a continual stream of data and must adapt to learn new predictions. This remains a difficult challenge because neural networks are known to suffer from catastrophic forgetting McCloskey and Cohen (1989), the phenomenon in which they abruptly lose the capability to solve a previously learnt task once information for solving a new task is incorporated.

It has been shown that catastrophic forgetting can be effectively alleviated by concurrently training on data of past tasks. These replay Robins (1995) and pseudo-replay Shin et al. (2017) mechanisms have been continuously improved Rebuffi et al. (2017); Lopez-Paz and Ranzato (2017); Chaudhry et al. (2018); a recent paper by Chaudhry et al. (2019) has even shown that forgetting can be prevented by co-training on as little as one single sample of past data per parametric update. Nonetheless, we found that the replayed data still needed to be sampled from a large set of externally stored memory. Sampling from an insufficient amount of external memory leads a network to overfit on early tasks and thus generalise poorly on subsequent tasks.

This paper combines the work of Chaudhry et al. (2019) with meta-learning for rapid knowledge acquisition. We extended their method with MetaSGD Li et al. (2017) to maintain high performance under the constraint of low memory consumption. We employed a meta-learner, which learns a learning rate per parameter per past task, to assist the training procedure of a task-specific base learner. This hybrid approach not only produced strong results when less memory was available, but also inherited several desirable meta-learning advantages for continual learning. Our approach demonstrated strong robustness in the realistic scenario of continually learning in the presence of noise, and it optimised base learner parameters to higher accuracy in fewer iterations.

Preliminaries

Our narrative below mainly follows Lopez-Paz and Ranzato (2017), Kirkpatrick et al. (2017), and Chaudhry et al. (2019).

The Framework of Continual Learning

In continual learning, we employ a task-specific base learner $f$ to sequentially learn the tasks $T_1, \ldots, T_K$ in a continuum of data

$(x_1, t_1, y_1), \ldots, (x_i, t_i, y_i), \ldots, (x_N, t_N, y_N)$   (1)

to map inputs $x_i$ to labels $y_i$ from datasets $D_{t_i}$ for all data, where $(x_i, y_i) \sim D_{t_i}$. The aim is to yield a learner that generalises well on subsequent tasks while maintaining high accuracy on learnt tasks.
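To make the continuum concrete, below is a minimal Python sketch that streams $(x, t, y)$ triplets task by task, in the spirit of Equation (1). The helper name `make_continuum` and the toy data are ours, for illustration only.

```python
import numpy as np

def make_continuum(datasets):
    """Yield (x, t, y) triplets task by task, as in Equation (1).
    `datasets` maps a task id t to an array pair (X_t, Y_t)."""
    for t, (X, Y) in datasets.items():
        for x, y in zip(X, Y):
            yield x, t, y

# Toy usage: two tasks with random 4-dimensional inputs.
rng = np.random.default_rng(0)
datasets = {t: (rng.normal(size=(3, 4)), rng.integers(0, 2, size=3))
            for t in (1, 2)}
for x, t, y in make_continuum(datasets):
    print(t, y)
```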

Continual learning is challenging Kemker et al. (2018). When a learner only observes data of a new task, its parameters optimised for an old task will be modified to suit the learning objective of the new task. We elaborate on this with the scenario presented in Figure 1, where a learner optimised for Task U is sequentially updated for an independent Task V. Naïvely training the learner with gradient descent will cause it to leave the minimum (low-loss region) on the loss landscape of Task U in pursuit of that of the new task Kirkpatrick et al. (2017) (Figure 1 takes inspiration from Figure 1 of Kirkpatrick et al. (2017)).

Figure 1: Exiting a local minimum while continual learning.

Replay with Episodic Memory

Catastrophic forgetting can be alleviated by replaying all data of past tasks, but this inevitably creates a large computational and memory burden. Several studies have shown that memory-based strategies do not require all past training data. Rebuffi et al. (2017) prevented forgetting by constructing an exemplar set of past data via representation learning, while Lopez-Paz and Ranzato (2017) preconditioned the gradients of the current learning task with small episodic memories of past tasks. This paper focuses specifically on episodic memory-based methods, as Lopez-Paz and Ranzato (2017)'s original work was continually streamlined by the important works of Chaudhry et al. (2018) and Chaudhry et al. (2019).

There are two steps to the episodic memory mechanism. Prior to the sequential training phase, Lopez-Paz and Ranzato (2017) first define the size of the memory storage allocated to each task. Then, upon training a new task, all stored data from the external memory is used to precondition the gradients of the current task. Though this method yielded very strong results while only storing a portion of all past data, it still had difficulty scaling as the number of subsequent tasks increased.

In their paper, Chaudhry et al. (2019) streamlined Lopez-Paz and Ranzato (2017)'s method with Experience Replay (ER), and showed that catastrophic forgetting can be alleviated without utilising all data stored within the episodic memories. Algorithm 1 shows their ER approach. The core mechanism of ER is to sample mini-batches $B_M$ from all stored data of past tasks (see line 7) to mitigate the computational burden of employing episodic memories. The sampled content is then trained concurrently with the existing batch $B$ of the current learning task, and SGD is used to update the parameters $\theta$ of base learner $f$ (see line 8). Impressively, they found that strong forget-prevention results can be achieved even when only 1 sample is drawn from the entire memory $M$.

1   Define memory storage $M$
2   Initialise base learner $f$ with parameters $\theta$
3   for task $t = 1, \ldots, T$ do
4       for batch $B$ from $D_t$ do
5           Update memory $M$ with $B$
6           if $t > 1$ then
7               Sample replay batch $B_M$ from all memory $M$
8               Update $\theta$ with $B \cup B_M$ via SGD
9           else
10              Update $\theta$ with $B$ via SGD (conventional update)
11          end if
12      end for
13  end for
Algorithm 1: Experience replay recipe
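For concreteness, the following is a minimal Python sketch of the recipe in Algorithm 1, with the algorithm's line numbers marked in comments. The helper names (`experience_replay`, `sgd_update`) are ours, not from the authors' released code; memory here simply keeps the first few samples of each task.

```python
import random

def experience_replay(tasks, sgd_update, mem_per_task=250, sample_size=10):
    """Minimal sketch of Algorithm 1. `tasks` is a list of lists of batches,
    each batch a list of (x, y) pairs; `sgd_update` applies one SGD step."""
    memory = {}                                   # line 1: define memory storage M
    for t, batches in enumerate(tasks, start=1):  # line 3
        memory[t] = []
        for batch in batches:                     # line 4
            if len(memory[t]) < mem_per_task:     # line 5: update memory M
                memory[t].extend(batch[:mem_per_task - len(memory[t])])
            if t > 1:                             # line 6
                stored = [s for k in memory for s in memory[k]]
                b_m = random.sample(stored, min(sample_size, len(stored)))  # line 7
                sgd_update(batch + b_m)           # line 8: train on B ∪ B_M
            else:
                sgd_update(batch)                 # line 10: conventional update
```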

Learning with Episodic Memory Sampled from a Tiny Memory Storage

To employ episodic memory, the practitioner needs to explicitly define the size of the external memory storage; but what are the disadvantages of selecting a small size? A small memory storage means that the replay mini-batch has a higher chance of re-sampling the same past data for concurrent training (see line 7 of Algorithm 1). A base learner would hence overfit on earlier tasks through consecutive re-exposure to a small set of past data, and thus underfit on future tasks. This issue was only lightly investigated in Chaudhry et al. (2019) (see their Section 5.5 and Appendix C), and they only compared medium-size to large-size memories (see Table 7 in Chaudhry et al. (2019), where they compared 1,000 to 20,000 memory units; what happens when we use 100?). In this work, we limit the memory size to only a handful of data per past task to study the consequences for continual learning, and we propose meta-learning as a remedy.

Meta-learning

Meta-learning studies rapid knowledge acquisition and has contributed to few-shot learning Ravi and Larochelle (2017), where a base learner is required to learn in a low-resource environment. In this paper, we formulate continual learning with tiny episodic memory storage as a low-resource problem. Though sequential learning does not lack data on subsequent tasks, our base learner is still required to learn from a small memory unit for past tasks.

Learning to Optimise with MetaSGD

One archetype of meta-learning focuses on optimisation. This type of approach is also known as learning to learn Andrychowicz et al. (2016). Optimisation is cast as the learning problem of a meta-learner which is used to configure the weights of the base learner network. It replaces the classic gradient descent of

$\theta_{t+1} = \theta_t - \alpha \nabla_{\theta_t} \mathcal{L}$   (2)

for a base learner $f$ with parameters $\theta_t$ at time $t$, where $\alpha$ is the learning rate and $\nabla_{\theta_t} \mathcal{L}$ is the gradient of the learner loss $\mathcal{L}$ with respect to the parameters.

One particularly simple method, MetaSGD Li et al. (2017), proposes to update learner parameters via

$\theta_{t+1} = \theta_t - \beta \odot \nabla_{\theta_t} \mathcal{L}$   (3)

where $\beta$ shares the dimensionality of $\theta$ and the symbol $\odot$ represents the element-wise product. The learning rates $\beta$ are the meta-learning target of MetaSGD, and hence a unique learning rate is learnt for every parameter to enable important features to be updated more quickly. In their paper, Li et al. (2017) showed that MetaSGD achieved strong results for few-shot learning; and Kuo et al. (2020) showed that MetaSGD could converge base learners much more rapidly than SGD (see Figure 5 of Kuo et al. (2020), where MetaSGD-like methods decreased updates from 37,500 to 10,000 steps). We will now simplify our notation by denoting $\nabla_{\theta_t} \mathcal{L}$ as $\nabla \mathcal{L}$.
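As a toy illustration of Equation (3), the sketch below applies one element-wise update. Learning $\beta$ via the meta-objective (MetaSGD's actual training loop) is omitted here, and the array values are arbitrary.

```python
import numpy as np

def metasgd_step(theta, grad, beta):
    """One MetaSGD update, Equation (3): element-wise learnt rates beta
    replace the single scalar learning rate of SGD."""
    return theta - beta * grad          # beta has the same shape as theta

theta = np.zeros(4)
grad = np.array([0.5, -1.0, 0.2, 0.0])
beta = np.array([0.02, 0.001, 0.02, 0.01])  # one learnt rate per parameter
print(metasgd_step(theta, grad, beta))
```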

An Analysis of the Experience Replay Update for Continual Learning

Under the conventional setup, the loss $\mathcal{L}$ aggregates every individual cost $\ell$ within a mini-batch $B$ via

$\mathcal{L} = \frac{1}{|B|} \sum_{(x, y) \in B} \ell(f(x), y)$   (4)

where $f(x)$ is the prediction of learner $f$ and $y$ is the corresponding true label. Hence, under the ER regime, every batch used in continual learning is substituted with $B \cup B_M$, where $B_M$ is sampled from memory $M$ (see line 7 of Algorithm 1). Then, due to the linearity of summation, the new loss can be written as

$\mathcal{L} = \frac{1}{|B \cup B_M|} \Big[ \sum_{(x, y) \in B} \ell(f(x), y) + \sum_{(x, y) \in B_M} \ell(f(x), y) \Big]$   (5)

If a base learner has observed $K$ tasks, we can highlight task specificity and rewrite Equation (5) as

$\mathcal{L} = \frac{1}{|B'|} \sum_{k=1}^{K} \sum_{(x, y) \in B'_k} \ell(f(x), y)$   (6)

where $B' = B \cup B_M$ and $B'_k$ is the data of task $k$ within $B'$. Then, with $\mathcal{L}_k$ as the task-wise loss, we have

$\mathcal{L} = \sum_{k=1}^{K} \mathcal{L}_k$   (7)

and this creates the standard parametric update of

$\theta_{t+1} = \theta_t - \alpha \sum_{k=1}^{K} \nabla \mathcal{L}_k$   (8)

in which we can employ MetaSGD and introduce task-wise learning rates $\beta_k$ as

$\theta_{t+1} = \theta_t - \sum_{k=1}^{K} \beta_k \odot \nabla \mathcal{L}_k$   (9)

to learn a learning rate per parameter per past task.
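A small numeric sketch of Equation (9) follows, assuming the per-task gradients have already been computed from $B$ and $B_M$. The values are arbitrary and purely illustrative.

```python
import numpy as np

def er_metasgd_update(theta, task_grads, task_betas):
    """Equation (9): sum of per-task gradients, each scaled element-wise by
    its own per-parameter learning rate vector."""
    step = sum(beta * grad for beta, grad in zip(task_betas, task_grads))
    return theta - step

theta = np.zeros(3)
task_grads = [np.array([1.0, 0.0, -0.5]), np.array([0.2, 0.4, 0.1])]
task_betas = [np.full(3, 0.001), np.full(3, 0.02)]  # old task damped, new task fast
print(er_metasgd_update(theta, task_grads, task_betas))
```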

Figure 2: Task-wise directional updates.

To elaborate, consider the scenario in Figure 2, where a learner observes its third task. The learner faces different directions of update from the competing learning objectives of different tasks. Interference Riemer et al. (2018) hence occurs and transfer of knowledge is difficult; this is further complicated because a learner can easily surpass thousands of parameters in practice Szegedy et al. (2015). Thus, instead of applying handcrafted rules such as SGD to update a base learner, we employ MetaSGD for Continual Learning (MetaSGD-CL) to explore a stream of changes in a series of task structures.

Implementation Details on MetaSGD-CL

Upon learning a new task $k$ (including task 1), MetaSGD-CL initialises a trainable vector $v_k$. In the context of Equation (9), we set

$\beta_k = \gamma \, \sigma(v_k)$   (10)

where $\sigma$ is the sigmoid function and $\gamma$ is a positive constant hyper-parameter. Thus every individual learning rate lies within the range $(0, \gamma)$. In this paper, we default $\gamma$ to 0.02 unless specified otherwise.

When sequentially learning a new task $k$ (excluding task 1), MetaSGD-CL freezes the past learning rates $\beta_j$ for all $j < k$ and updates the base learner with the formulation below:

$\theta_{t+1} = \theta_t - \Big( \beta_k \odot \nabla \mathcal{L}_k + \sum_{j=1}^{k-1} \beta_j \odot \nabla \mathcal{L}_j \Big)$   (11)

where only $\beta_k$ remains trainable.
Figure 3: Performances for permuted MNIST. The top row reflects catastrophic forgetting; the bottom row reflects overfitting.

Equation (11) accounts for alignments in the task-wise gradients $\nabla \mathcal{L}_j$ that arise from similarity among arbitrary task structures; it hence prevents over-parametrisation in any given feature dimension.

The loss function of MetaSGD-CL is adapted from Andrychowicz et al. (2016):

$\mathcal{L}_{\text{meta}} = \sum_{(x, y) \in B} \ell(f_{\theta_{t+1}}(x), y)$   (12)

that is, the loss of the updated learner $f_{\theta_{t+1}}$ on the original data $B$. Thus, the objective of MetaSGD-CL is to further fine-tune learner $f$. The meta-learning targets $v_k$ are updated with Adam Kingma and Ba (2015) with learning rate 0.01.
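Putting Equations (10) to (12) together, the following is a minimal PyTorch sketch of one MetaSGD-CL step for a single parameter tensor. It is our simplified reading (a real base learner has many tensors, and the sigmoid bounding of Equation (10) is our reconstruction of the stated $(0, \gamma)$ range), not the authors' released implementation.

```python
import torch

GAMMA = 0.02  # default upper bound on each learning rate, Equation (10)

class MetaSGDCL:
    """Minimal sketch of Equations (10)-(12) for one parameter tensor."""

    def __init__(self):
        self.vs = []    # one trainable vector v_k per task observed
        self.opt = None

    def observe_new_task(self, theta):
        for v in self.vs:                 # freeze past learning rates, Eq. (11)
            v.requires_grad_(False)
        v_k = torch.zeros_like(theta, requires_grad=True)
        self.vs.append(v_k)
        self.opt = torch.optim.Adam([v_k], lr=0.01)  # meta-optimiser, Eq. (12)

    def step(self, theta, task_grads, meta_loss_fn):
        # Inner update, Equations (9)-(11): beta_k = GAMMA * sigmoid(v_k).
        betas = [GAMMA * torch.sigmoid(v) for v in self.vs]
        theta_new = theta - sum(b * g for b, g in zip(betas, task_grads))
        # Meta update, Equation (12): loss of the updated learner on B.
        self.opt.zero_grad()
        meta_loss_fn(theta_new).backward()
        self.opt.step()
        return theta_new.detach()

# Toy usage with a quadratic stand-in for the learner loss.
meta = MetaSGDCL()
theta = torch.ones(4)
meta.observe_new_task(theta)
grads = [torch.full((4,), 0.3)]           # one task observed so far
theta = meta.step(theta, grads, lambda th: (th ** 2).sum())
```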

Experiments

This section presents benchmark results, demonstrates the benefits of MetaSGD-CL over ER, and reports an ablation study.

Dataset

We conducted experiments on permuted MNIST Kirkpatrick et al. (2017) and incremental Cifar100 Lopez-Paz and Ranzato (2017). Previously, Lopez-Paz and Ranzato (2017) and Chaudhry et al. (2019) showed that memory-based approaches outperformed continual learning techniques of other archetypes and achieved state-of-the-art results. However, we will show that when the amount of memory decreases, ER performance can drop sharply even on the comparatively simple permuted MNIST tasks.

MNIST LeCun et al. (1998) contains 60,000 training and 10,000 test images in 10 classes of 28×28 grey-scale handwritten digits. Permuted MNIST first reshapes each image into a vector of 784 (= 28×28) pixels, then applies a permutation unique to each task on all vectors of that task. In our experiments, we sequentially trained over 10 permuted MNIST tasks.
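A short sketch of the task construction just described; the helper name `make_permuted_task` and the toy array are ours.

```python
import numpy as np

def make_permuted_task(images, seed):
    """Apply one fixed pixel permutation to every flattened image, yielding
    one permuted MNIST task. `images` has shape (N, 28, 28)."""
    rng = np.random.default_rng(seed)
    perm = rng.permutation(28 * 28)          # one permutation per task
    flat = images.reshape(len(images), -1)   # vectors of 784 pixels
    return flat[:, perm]

# Ten tasks = ten seeds, each defining its own fixed permutation.
toy = np.arange(2 * 28 * 28).reshape(2, 28, 28)
tasks = [make_permuted_task(toy, seed) for seed in range(10)]
```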

Cifar100 Krizhevsky (2009) is a coloured dataset with 100 classes of objects. There are 50,000 training and 10,000 test images, all presented in 32×32 pixel boxes. In our incremental class experiments, we sequentially introduced 5 classes per task, drawn without replacement from the dataset. We reserved the first 50 classes for pre-training a classifier and tested on the remaining 50 classes; hence there were 10 tasks in total. See the next subsection for more details.
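The split just described can be sketched as follows; the function name and toy labels are ours, for illustration only.

```python
import numpy as np

def incremental_splits(labels, first_class=50, classes_per_task=5, n_tasks=10):
    """Index masks for the incremental setup: classes 0-49 are reserved for
    pre-training; the remaining 50 classes form 10 tasks of 5 classes each."""
    tasks = []
    for k in range(n_tasks):
        lo = first_class + k * classes_per_task
        tasks.append(np.where((labels >= lo) & (labels < lo + classes_per_task))[0])
    return tasks

labels = np.repeat(np.arange(100), 3)    # toy stand-in for Cifar100 labels
print([len(ix) for ix in incremental_splits(labels)])
```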

Methods FA1 (%) ACC (%)
MetaSGD-CL (Ours) 81.02 82.19
ER 80.60 69.42
GEM 80.71 79.43
EWC 64.80 71.84
HAT 74.03 76.17
Singular 62.18 71.34
Table 1: Metrics for permuted MNIST in Figure 3.

Base Learners

For permuted MNIST, we followed Lopez-Paz and Ranzato (2017) and trained MLPs as base learners. For incremental Cifar100, we followed Mai et al. (2020) (winners of the CVPR CLVision 2020 challenge) and used their best practice, which employs a fixed feature extractor and only sequentially trains the classifier networks. We pre-trained a ResNet18 He et al. (2016) as feature extractor on the first 50 classes of Cifar100, then updated MLP base learners as classifiers. All MLP base learners in our study had 2 hidden layers with 100 dimensions, followed by a softmax layer.

Following Lopez-Paz and Ranzato (2017), we used batch size 10. Each task of permuted MNIST observed 100 batches, and each task of incremental Cifar100 was trained for 1 epoch. Our MetaSGD-CL was described in the previous section; all baselines were updated with SGD with learning rate 0.01.

Metrics

The 2 metrics used in this paper are based on a matrix $R$ that stores all accuracies over the entire course of continual learning. Row $i$ records the accuracies after training on the $i$-th dataset, and column $j$ records the accuracies of the $j$-th task.

The severity of catastrophic forgetting can be reflected via the Final Accuracy of Task 1, $\text{FA1} = R_{T, 1}$; while the Average Accuracy, $\text{ACC} = \frac{1}{T} \sum_{j=1}^{T} R_{T, j}$, shows the overall performance and the severity of overfitting of the base learner. For both metrics, larger is better.
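Both metrics read off the last row of the accuracy matrix, as the short sketch below shows; $R$ here is a toy $T \times T$ array, zero-indexed in code.

```python
import numpy as np

def fa1(R):
    """Final Accuracy of Task 1: accuracy on task 1 after the last task."""
    return R[-1, 0]

def acc(R):
    """Average Accuracy: mean accuracy over all tasks after the last task."""
    return R[-1].mean()

# Toy accuracy matrix: R[i, j] = accuracy on task j after training task i.
R = np.array([[0.95, 0.0], [0.80, 0.93]])
print(fa1(R), acc(R))
```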

Figure 4: Performances for incremental Cifar100.
Methods FA1 (%) ACC (%)
MetaSGD-CL (Ours) 81.52 81.06
ER 79.28 79.50
GEM 79.92 80.88
EWC 67.46 71.98
HAT 71.98 78.86
Singular 62.86 73.27
Table 2: Metrics for incremental Cifar100 in Figure 4.

Continual Learning Techniques

Our benchmark experiments included 6 setups. Singular was the scenario in which we naïvely trained a base learner sequentially. Then, we tested our MetaSGD-CL against the baseline ER Chaudhry et al. (2019). In addition, we considered the memory-based Gradient Episodic Memory (GEM) of Lopez-Paz and Ranzato (2017), the regularisation-based Elastic Weight Consolidation (EWC) of Kirkpatrick et al. (2017), and the dynamic-architecture-based Hard Attention to the Task (HAT) of Serra et al. (2018).

Remark 1: Preventing Catastrophic Forgetting

Our results on permuted MNIST are in Figure 3 and Table 1; those for incremental Cifar100 are in Figure 4 (refer to our released code to reformat Figure 4 in the style of Figure 3 for further inspection) and Table 2. Based on the FA1 metric in the 2 tables, we validated Lopez-Paz and Ranzato (2017) and Chaudhry et al. (2019)'s finding that memory-based approaches are more effective in preventing catastrophic forgetting.

Our memory-based approaches used hard storage Lopez-Paz and Ranzato (2017), which assigns 250 memory units to each task. For each update, the less computationally efficient GEM used all stored data to prevent forgetting, while both MetaSGD-CL and ER randomly sampled 10 data from each task-wise storage and appended them as $B_M$ to the original batch (see Equation (5)). Samples of old tasks do not change throughout training.

The base learners of HAT had hidden dimensions of 400 instead. With only 100 hidden dimensions, HAT could not prevent forgetting and behaved as reported in Ahn et al. (2019).

Remark 2: Three Advantages of MetaSGD-CL over ER

Figure 5: Ring buffer of different sizes for permuted MNIST.

Remark 2.1: Does not overfit on old tasks

From Figure 3 and Table 1, ER was comparable to MetaSGD-CL in FA1 but much lower in ACC. This is a clear indication of ER's overfitting problem. For incremental Cifar100 and Table 2, Mai et al. (2020)'s experimental setup limited overfitting to some extent by only continually training the classifier network. However, the bottom subplot of Figure 4 shows that ER's variability (in blue) was much larger than MetaSGD-CL's (in yellow); hence overfitting persisted.

We further tested overfitting with the alternative memory storage of the ring buffer Chaudhry et al. (2019). A ring buffer shares its memory units among all tasks (with a buffer of $|M|$ units and $T$ tasks, this amounts to $|M| / T$ units per task). By design, the buffer is under-utilised in the early phases of sequential learning, and stored samples do not change. Compared to hard storage, the ring buffer is much more computationally efficient. However, it is also more likely to incur overfitting, because the data of all past tasks now share the same buffer.
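Below is a minimal sketch of such a shared buffer, assuming simple first-in-first-out overwriting; Chaudhry et al. (2019)'s ring buffer is organised per class, a detail we gloss over here.

```python
from collections import deque

class RingBuffer:
    """Single memory shared by all tasks: once the capacity is reached,
    the oldest stored samples are overwritten."""

    def __init__(self, capacity):
        self.data = deque(maxlen=capacity)   # shared across all tasks

    def add(self, samples):
        self.data.extend(samples)            # evicts the oldest when full

buffer = RingBuffer(capacity=100)
buffer.add([("x", "y")] * 150)               # only the last 100 remain
print(len(buffer.data))
```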

Figure 5 shows the results for MetaSGD-CL and ER on permuted MNIST with ring buffers of 1,000, 250, and 100 units. Note that these settings are much smaller than the scenarios tested in Chaudhry et al. (2019). As in Remark 1, each iteration randomly sampled 10 past data as $B_M$, but from one unified storage (see line 7 of Algorithm 1).

Unsurprisingly, the performance of either method decreased as the memory size dropped. This is a natural consequence of overfitting through repeatedly re-sampling from a small set of past data. ER is one of the state-of-the-art techniques; but still, it suffered a significant degradation on the comparatively simple permuted MNIST benchmark when the ring buffer size shrank from 1,000 to 100 units. MetaSGD-CL was more capable of maintaining high accuracy.

Remark 2.2: Rapid Knowledge Acquisition

We also tested 2 realistic continual learning scenarios: first, when there was a limited amount of source data; and second, when data was collected from a noisy environment.

In Figure 6, we tested MetaSGD-CL and ER on permuted MNIST with a ring buffer, but reduced the available training data from 100 iterations to 25 iterations per task. Unsurprisingly, test accuracy dropped; in addition, overfitting continued to occur with ER. Our MetaSGD-CL approach was able to leverage meta-learning and yielded higher accuracy in fewer iterations; moreover, its mean accuracy increased concurrently with its accuracy on task 1.

Figure 6: Permuted MNIST but with iteration reduction.
Figure 7: Permuted MNIST but with noise injection.

Remark 2.3: Strong Robustness to Noise

In Figure 7, we tested MetaSGD-CL and ER on permuted MNIST with a ring buffer, and demonstrated the consequences of noise injection (NI) in both the training data and the stored memory. NI randomly shuffles a portion of the pixels in the MNIST images. Our lowest setting injected 10% noise, and the highest setting injected up to 50% noise.
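The NI corruption just described can be sketched as follows; the helper name `inject_noise` is ours, and the exact shuffling scheme is our reading of "randomly shuffled a section of pixels".

```python
import numpy as np

def inject_noise(image_vec, fraction, rng):
    """Randomly shuffle a given fraction of the pixels in a flattened image,
    mimicking the noise injection (NI) described above."""
    noisy = image_vec.copy()
    idx = rng.choice(noisy.size, size=int(fraction * noisy.size), replace=False)
    noisy[idx] = rng.permutation(noisy[idx])   # shuffle only the chosen pixels
    return noisy

rng = np.random.default_rng(0)
image = np.arange(784, dtype=float)
print(inject_noise(image, fraction=0.10, rng=rng)[:5])  # the 10% noise setting
```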

We found that MLPs trained with MetaSGD-CL were less susceptible to NI than their counterparts trained with ER, with higher accuracy and smaller variability. This was likely because Equation (9) assigned higher learning rates to features that were less prone to NI.

Remark 3: Behaviours on Learning Rate

Remark 3.1: Convergence in β Values

As previously mentioned, the maximal value of each learning rate was defaulted to $\gamma = 0.02$ (see Equation (10)). The value of $\gamma$ was purposely set low because we hypothesised that there existed conflicting objectives between meta-learning and continual learning.

As shown in Figure 6, meta-learning is naturally capable of rapid knowledge acquisition. That is, we hypothesised that with large $\gamma$ values, $\beta$ would grow excessively and thereby indirectly cause overfitting on old tasks.

To our surprise, a larger $\gamma$ yielded 76.95 ACC, only a minor degradation from the 82.19 ACC of the default setting shown in Table 1. We found that this was because extreme $\beta$ values converged as the base learner observed more tasks. As shown in Table 3 and Figure 8, only a minor proportion of $\beta$ values lay near the boundaries of the range $(0, \gamma)$. We found that as the base learner observed more tasks, the proportion of large-magnitude task-wise $\beta$ decreased, while that of small-magnitude task-wise $\beta$ increased. That is, smaller parametric updates were made to the base learner from the data of later tasks.

Furthermore, a larger proportion of large $\beta$ values was found in the task-wise $\beta$ of the deeper base learner layers. Thus MetaSGD-CL prioritised updating the weights of the deeper layers rather than those of the shallower layers.

Proportion of large-magnitude values (in %) in task-wise learning rates $\beta$
Tasks          2     4     6     8     10
Layer 1        1.39  0.63  0.47  0.46  0.36
Layer 2        2.26  1.41  1.33  1.32  1.30
Softmax Layer  3.64  2.28  1.88  1.76  1.74

Proportion of small-magnitude values (in %) in task-wise learning rates $\beta$
Tasks          2     4     6     8     10
Layer 1        1.43  3.90  6.56  8.31  8.79
Layer 2        4.29  6.59  7.71  10.51 11.27
Softmax Layer  6.02  7.50  8.38  8.92  11.58

Table 3: Ratios of extreme task-wise $\beta$s with the default $\gamma = 0.02$.
Figure 8: Transition in task-wise $\beta$ magnitudes.
Methods                              FA1 (%)  ACC (%)
Un-ablated MetaSGD-CL                81.02    82.19
Singular                             62.18    71.34
Replace all old $\beta_j$ with 0     69.81    74.41
Replace all old $\beta_j$ with 0.01  73.48    77.16
Replace all old $\beta_j$ with 0.1   72.57    75.80
Table 4: Ablation results for permuted MNIST.

Remark 3.2: Ablation Study

To further understand MetaSGD-CL, we report 3 ablation studies in Table 4. When learning a new task, we removed the learnt learning rates $\beta_j$ of the old tasks and replaced them with the constants 0, 0.01, or 0.1.

When the $\beta_j$ of the old tasks were replaced with 0, catastrophic forgetting occurred. However, these ablated performances were still much better than naïve sequential learning. Interestingly, when the $\beta_j$ of the old tasks were replaced with non-zero values, the ablated performances were similar to one another. We attribute this to the formulation of Equation (9): as long as some information of the past is given, MetaSGD-CL is able to learn appropriate $\beta_k$ for the new task and indirectly balance the base learner's parametric configuration across all tasks.

Related Studies and Future Work

Like EWC and Learning without Forgetting (LwF) Li and Hoiem (2017), MetaSGD-CL alleviates catastrophic forgetting via modifications in the backward pass. Thus, it could likely be adapted to prevent forgetting via regularisation as well.

Alternative meta-learning approaches have also been employed in continual learning. Riemer et al. (2018) combined GEM with Reptile Nichol et al. (2018), while Javed and White (2019) continually learned via representation learning with Model-Agnostic Meta-Learning (MAML) Finn et al. (2017)-like techniques. A large-scale comparison of such hybrid approaches would thus be an interesting future study.

Conclusion

This paper introduced MetaSGD-CL as an extension of Chaudhry et al. (2019)'s ER technique to remedy overfitting on old tasks when memory-based continual learning is allocated a small memory size. MetaSGD-CL leverages meta-learning by learning a learning rate per parameter per past task. It alleviated catastrophic forgetting, prevented base learners from overfitting, and achieved high performance on all tasks. In addition, MetaSGD-CL optimised base learners in fewer iterations and showed higher robustness against noise in the training data.

Acknowledgments

This research was supported by the Australian Government Research Training Program (AGRTP) Scholarship. We also thank our reviewers for the constructive feedback.

References