La-MAML: Look-ahead Meta Learning for Continual Learning

07/27/2020 ∙ by Gunshi Gupta, et al. ∙ Université de Montréal ∙ Carnegie Mellon University ∙ Montréal Institute of Learning Algorithms

The continual learning problem involves training models with limited capacity to perform well on a set of an unknown number of sequentially arriving tasks. While meta-learning shows great potential for reducing interference between old and new tasks, the current training procedures tend to be either slow or offline, and sensitive to many hyper-parameters. In this work, we propose Look-ahead MAML (La-MAML), a fast optimisation-based meta-learning algorithm for online-continual learning, aided by a small episodic memory. Our proposed modulation of per-parameter learning rates in our meta-learning update allows us to draw connections to prior work on hypergradients and meta-descent. This provides a more flexible and efficient way to mitigate catastrophic forgetting compared to conventional prior-based methods. La-MAML achieves performance superior to other replay-based, prior-based and meta-learning based approaches for continual learning on real-world visual classification benchmarks.


1 Introduction

Embodied or interactive agents that accumulate knowledge and skills over time must possess the ability to continually learn. Catastrophic forgetting catastrophic, catastrophic2, one of the biggest challenges in this setup, can occur when the i.i.d. sampling conditions required by stochastic gradient descent (SGD) are violated as the data belonging to the different tasks to be learnt arrives sequentially. Algorithms for continual learning (CL) must also use their limited model capacity efficiently, since the number of future tasks is unknown. Ensuring gradient-alignment across tasks is therefore essential to make shared progress on their objectives. Gradient Episodic Memory (GEM) lopez2017gradient investigated the connection between weight sharing and forgetting in CL and developed an algorithm that explicitly tried to minimise gradient interference. This is an objective that meta-learning algorithms implicitly optimise for (refer to nichol2018first for derivations of the effective parameter update made in first- and second-order meta-learning algorithms). Meta Experience Replay (MER) riemer2018learning formalized the transfer-interference trade-off and showed that the gradient alignment objective of GEM coincides with the objective optimised by the first-order meta-learning algorithm Reptile nichol2018first.

Besides aligning gradients, meta-learning algorithms show promise for CL since they can directly use the meta-objective to influence model optimisation and improve on auxiliary objectives like generalisation or transfer. This avoids having to define heuristic incentives like sparsity 2018sr-nn for better CL. The downside is that they are usually slow and hard to tune, effectively rendering them more suitable for offline continual learning javed2019meta, riemer2018learning. In this work, we overcome these difficulties and develop a gradient-based meta-learning algorithm for efficient, online continual learning. We first propose a base algorithm for continual meta-learning referred to as Continual-MAML (C-MAML) that utilizes a replay buffer and optimizes a meta-objective that mitigates forgetting. Subsequently, we propose a modification to C-MAML, named La-MAML, which incorporates modulation of per-parameter learning rates (LRs) to pace the learning of a model across tasks and time. Finally, we show that the algorithm is scalable, robust and achieves favourable performance on several benchmarks of varying complexity.

2 Related work

Relevant CL approaches can be roughly categorized into replay-based, regularisation (or prior-based) and meta-learning-based approaches.

In order to circumvent the issue of catastrophic forgetting, replay-based methods maintain a collection of samples from previous tasks in memory. Approaches utilising an episodic-buffer castro2018end, rebuffi2017icarl uniformly sample old data points to mimic the i.i.d. setup within continual learning. Generative-replay 2017genreplay trains generative models to be able to replay past samples, with scalability concerns arising from the difficulty of modeling complex non-stationary distributions. GEM lopez2017gradient and A-GEM chaudhry2018agem take memory samples into account to determine altered low-interference gradients for updating parameters.

Regularisation-based methods avoid using replay at all by constraining the network weights according to heuristics intended to ensure that performance on previous tasks is preserved. This involves penalising changes to weights deemed important for old tasks Kirkpatrick3521, or enforcing weight or representational sparsity 2018selflesseq to ensure that only a subset of neurons remains active at any point of time. The latter has been shown to reduce the possibility of catastrophic interference across tasks 2018sr-nn, serra2018hat.

Meta-Learning-based approaches are fairly recent and have shown impressive results on small benchmarks like Omniglot and MNIST. MER riemer2018learning, inspired by GEM lopez2017gradient, utilises replay to incentivise alignment of gradients between old and new tasks. Online-aware Meta Learning (OML) javed2019meta introduces a meta-objective for a pre-training algorithm to learn an optimal representation offline, which is subsequently frozen and used for CL. 2017online-meta-learn, 2019finn-onlineml, 2018onlinelearning investigate orthogonal setups in which a learning agent uses all previously seen data to adapt quickly to an incoming stream of data, thereby ignoring the problem of catastrophic forgetting. Our motivation lies in developing a scalable, online algorithm capable of learning from limited cycles through streaming data with reduced interference on old samples. In the following sections, we review background concepts and outline our proposed algorithm. We also note interesting connections to prior work not directly pertaining to CL.

3 Preliminaries

We consider a setting where a sequence of $T$ tasks $[\tau_1, \tau_2, \ldots, \tau_T]$ is learnt by observing their training data $[D_1, D_2, \ldots, D_T]$ sequentially. We define $D_k = \{(X_k^i, Y_k^i)\}_{i=1}^{N_k}$ as the set of $N_k$ input-label pairs randomly drawn from the distribution of task $\tau_k$. At any time-step $j$ during online learning, we aim to minimize the empirical risk of the model on all the $t$ tasks seen so far ($\tau_{1:t}$), given limited access to data from previous tasks ($\tau_{1:t-1}$). We refer to this objective as the cumulative risk, given by:

$$\min_{\theta} \; \sum_{\tau=1}^{t} \mathbb{E}_{(X_\tau, Y_\tau) \sim D_\tau} \Big[\, \ell_\tau\big( f_\tau(X_\tau;\, \theta),\, Y_\tau \big) \,\Big] \tag{1}$$

where $\ell_\tau$ is the loss on task $\tau$ and $f_\tau$ is a learnt, possibly task-specific mapping from inputs to outputs using parameters $\theta$. $L_{1:t}(\theta) = \sum_{\tau=1}^{t} \ell_\tau(\theta)$ is the sum of all task-wise losses for tasks $\tau$ where $\tau$ goes from $1$ to $t$. Let $L$ denote some loss objective to be minimised. Then the SGD operator $U$ acting on parameters $\theta$, denoted by $U(\theta)$, is defined as:

$$U(\theta) = \theta' = \theta - \alpha \nabla_\theta L(\theta) \tag{2}$$

where $\theta'$ denotes the updated parameters. $U$ can be composed for $k$ updates as $U_k(\theta) = U \circ U \circ \ldots \circ U(\theta)$. $\alpha$ is a scalar or a vector LR. $U_k(\theta, D)$ implies the $k$ gradient updates are made on data samples drawn from $D$. We now introduce the MAML finn2017model and OML javed2019meta algorithms, that we build upon in Section 4.
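To make the operator notation concrete, the following is a minimal sketch of $U$ and its $k$-fold composition $U_k$ in Python (NumPy). The least-squares loss and its gradient are placeholder choices of ours for illustration, not part of the paper:

import numpy as np

def U(theta, grad_fn, alpha, batch):
    # One SGD step: theta' = theta - alpha * grad(L)(theta).
    # alpha may be a scalar or a per-parameter vector (broadcasting handles both).
    return theta - alpha * grad_fn(theta, batch)

def U_k(theta, grad_fn, alpha, batches):
    # Compose U for k updates, consuming one batch per inner step.
    for batch in batches:
        theta = U(theta, grad_fn, alpha, batch)
    return theta

# Toy example: least-squares loss L(theta) = ||x @ theta - y||^2 / 2
def grad_fn(theta, batch):
    x, y = batch
    return x.T @ (x @ theta - y)

theta0 = np.zeros(3)
batches = [(np.random.randn(10, 3), np.random.randn(10)) for _ in range(5)]
theta_k = U_k(theta0, grad_fn, alpha=0.01, batches=batches)  # k = 5 inner updates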

Model-Agnostic Meta-Learning (MAML): Meta-learning Schmidhuber1987, or learning-to-learn learn2learn, has emerged as a popular approach for training models amenable to fast adaptation on limited data. MAML finn2017model proposed optimising model parameters to learn a set of tasks while improving on auxiliary objectives like few-shot generalisation within the task distributions. We review some common terminology used in gradient-based meta-learning: 1) at a given time-step $j$ during training, model parameters $\theta_0^j$ (or $\theta_0$ for simplicity) are often referred to as an initialisation, since the aim is to find an ideal starting point for few-shot gradient-based adaptation on unseen data. 2) Fast or inner-updates refer to gradient-based updates made to a copy of $\theta_0$, optimising some inner objective (in this case, $\ell_\tau$ for some task $\tau$). 3) A meta-update involves the trajectory of fast updates from $\theta_0$ to $\theta_k$, followed by making a permanent gradient update (or slow-update) to $\theta_0$. This slow-update is computed by evaluating an auxiliary objective (or meta-loss) $L_{meta}$ on $\theta_k$, and differentiating through the trajectory to obtain $\nabla_{\theta_0} L_{meta}(\theta_k)$. MAML thus optimises $\theta_0$ at time $j$, to perform optimally on tasks in the distribution $P(\mathcal{T})$ after undergoing a few gradient updates on their samples. It optimises, in every meta-update, the objective:

$$\min_{\theta_0} \; \mathbb{E}_{\tau \sim P(\mathcal{T})} \Big[ L_{meta}\big( U_k(\theta_0,\, D_\tau) \big) \Big] \tag{3}$$

Equivalence of Meta-Learning and CL Objectives: The approximate equivalence of first- and second-order meta-learning algorithms like Reptile and MAML was shown in nichol2018first. MER riemer2018learning then showed that their CL objective of minimising loss on, and aligning gradients between, the set of tasks $\tau_{1:t}$ seen till any time $t$ (on the left) can be optimised by the Reptile objective (on the right), i.e.:

$$\min_{\theta} \; \mathbb{E}_{\tau_{1:t}}\!\left[ \sum_{p=1}^{t} \ell_p(\theta) \;-\; \alpha \sum_{p,\,q \,\le\, t} \frac{\partial \ell_p(\theta)}{\partial \theta} \cdot \frac{\partial \ell_q(\theta)}{\partial \theta} \right] \;\equiv\; \min_{\theta} \; \mathbb{E}_{\tau_{1:t}}\!\Big[ L_{1:t}\big( U_{1:t}(\theta) \big) \Big] \tag{4}$$

where the meta-loss $L_{1:t}$ is evaluated on samples from tasks $\tau_{1:t}$. This implies that the procedure to meta-learn an initialisation coincides with learning optimal parameters for CL.

Online-aware Meta-Learning (OML): javed2019meta proposed to meta-learn a Representation-Learning Network (RLN) to provide a representation suitable for CL to a Task-Learning Network (TLN). The RLN’s representation is learnt in an offline phase, where it is trained using catastrophic forgetting as the learning signal. Data from a fixed set of tasks ($\tau_{1:T}$) is repeatedly used to evaluate the RLN and TLN as the TLN undergoes temporally correlated updates. In every meta-update’s inner loop, the TLN undergoes fast updates on streaming task data with a frozen RLN. The RLN and updated TLN are then evaluated through a meta-loss computed on data from $\tau_{1:T}$ along with the current task. This tests how the performance of the model has changed on $\tau_{1:T}$ in the process of trying to learn the streaming task. The meta-loss is then differentiated to get gradients for slow updates to the TLN and RLN. This composition of two losses, to simulate CL in the inner loop and test forgetting in the outer loop, is referred to as the OML objective. The RLN learns to eventually provide a better representation to the TLN for CL, one which is shown to have emergent sparsity.
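The structure of one OML meta-update can be illustrated with a runnable toy sketch of ours: a linear RLN ($W_r$, frozen in the inner loop) feeding a linear TLN ($w_t$), squared-error loss, and a first-order approximation for the slow updates. This is our own schematic, not the authors' implementation:

import numpy as np

def tln_grad(W_r, w_t, X, y):
    # Gradient w.r.t. the TLN weights w_t of 1/2 ||(X W_r) w_t - y||^2, RLN frozen.
    h = X @ W_r
    return h.T @ (h @ w_t - y)

def oml_meta_update(W_r, w_t, stream, replay, inner_lr, outer_lr):
    # Inner loop: fast TLN updates on temporally correlated streaming data.
    w_fast = w_t.copy()
    for X, y in stream:
        w_fast = w_fast - inner_lr * tln_grad(W_r, w_fast, X, y)
    # Outer loop: meta-loss on old + current data measures how much was
    # forgotten; (first-order) slow updates are applied to RLN and TLN.
    X_m = np.vstack([X for X, _ in replay + stream])
    y_m = np.concatenate([y for _, y in replay + stream])
    h = X_m @ W_r
    err = h @ w_fast - y_m
    g_tln = h.T @ err                        # d(meta-loss)/d(w_t), first-order
    g_rln = X_m.T @ np.outer(err, w_fast)    # d(meta-loss)/d(W_r), first-order
    return W_r - outer_lr * g_rln, w_t - outer_lr * g_tln

# Example: 8-dim inputs, 4-dim representation; 3 streaming and 2 replay batches.
rng = np.random.default_rng(0)
W_r, w_t = rng.normal(size=(8, 4)), np.zeros(4)
make = lambda: (rng.normal(size=(5, 8)), rng.normal(size=5))
W_r, w_t = oml_meta_update(W_r, w_t, [make() for _ in range(3)],
                           [make() for _ in range(2)], 0.01, 0.001)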

4 Proposed approach

In the previous section, we saw that the OML objective can directly regulate CL behaviour, and that MER exploits the approximate equivalence of meta-learning and CL objectives. We noted that OML trains a static representation offline and that MER’s algorithm is prohibitively slow. We show that optimising the OML objective online through a multi-step MAML procedure is equivalent to a more sample-efficient CL objective. In this section, we describe Continual-MAML (C-MAML), the base algorithm that we propose for online continual learning. We then detail an extension to C-MAML, referred to as Look-Ahead MAML (La-MAML), outlined in Algorithm 1.

4.1 C-MAML

C-MAML aims to optimise the OML objective online, so that learning on the current task doesn’t lead to forgetting on previously seen tasks. We define this objective, adapted to optimise a model’s parameters $\theta_0^j$ instead of a representation, at time-step $j$, as:

$$\min_{\theta_0^j} \; \mathbb{E}_{S_t^j}\Big[ L_{1:t}\big( U_k(\theta_0^j,\, S_t^j) \big) \Big] \tag{5}$$

where $S_t^j$ is a stream of data tuples $(X_t^j, Y_t^j)$ from the current task $\tau_t$ that is seen by the model at time $j$. The meta-loss $L_{1:t}$ is evaluated on $\theta_k^j = U_k(\theta_0^j, S_t^j)$. It evaluates the fitness of $\theta_k^j$ for the continual learning prediction task defined in Eq. 1 until $\tau_t$. We omit the implied data argument that is the input to each loss $\ell_\tau$ in $L_{1:t}$ for any task $\tau$. We will show in Appendix B that optimising our objective in Eq. 5 through the $k$-step MAML update in C-MAML also coincides with optimising the CL objective of AGEM chaudhry2018agem:

$$\min_{\theta} \; \mathbb{E}_{\tau_{1:t}}\!\left[ \sum_{p=1}^{t} \ell_p(\theta) \;-\; \alpha\, \frac{\partial \ell_t(\theta)}{\partial \theta} \cdot \sum_{p=1}^{t} \frac{\partial \ell_p(\theta)}{\partial \theta} \right] \tag{6}$$

This differs from Eq. 4’s objective by being asymmetric: it focuses on aligning the gradients of $\tau_t$ and the average gradient of $\tau_{1:t}$, instead of aligning all the pair-wise gradients between tasks $\tau_{1:t}$. In Appendix D, we show empirically that gradient alignment amongst old tasks doesn’t degrade while a new task is learnt, avoiding the need to repeatedly optimise the inter-task alignment between them. This results in a drastic speedup over MER’s objective (Eq. 4), which tries to align all $\tau_{1:t}$ equally and thus resamples incoming samples to form a uniformly distributed batch over $\tau_{1:t}$. Since each $\tau$ then has a $1/t$-th contribution in gradient updates, it becomes necessary for MER to take multiple passes over many such uniform batches including $\tau_t$.

Figure 1: The proposed La-MAML algorithm: For every batch of data, the initial weights $\theta_0^j$ undergo a series of $k$ fast updates to obtain $\theta_k^j$, which is evaluated against a meta-loss to backpropagate gradients with respect to the weights $\theta_0^j$ and LRs $\alpha^j$. First $\alpha^j$ is updated to $\alpha^{j+1}$, which is then used to update $\theta_0^j$ to $\theta_0^{j+1}$. The blue boxes indicate fast weights while the green boxes indicate gradients for the slow updates. LRs and weights are updated in an asynchronous manner.

During training, a replay-buffer $R$ is populated through reservoir sampling on the incoming data stream, as in riemer2018learning. At the start of every meta-update, a batch $b$ is sampled from the current task. $b$ is also combined with a batch sampled from $R$ to form the meta-batch $b_m$, representing samples from both old and new tasks. $\theta_0^j$ is updated through SGD-based inner-updates by seeing the current task’s samples from $b$ one at a time. The outer-loss or meta-loss is evaluated on $b_m$. It indicates the performance of parameters $\theta_k^j$ on all the tasks $\tau_{1:t}$ seen till time $j$. The complete training procedure is described in Appendix C.
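A minimal sketch of the buffer maintenance and meta-batch assembly described above; the buffer capacity and batch sizes are illustrative assumptions of ours:

import random

class ReservoirBuffer:
    # Keeps a uniform sample over everything seen so far (reservoir sampling).
    def __init__(self, capacity):
        self.capacity, self.seen, self.data = capacity, 0, []

    def push(self, item):
        self.seen += 1
        if len(self.data) < self.capacity:
            self.data.append(item)
        else:
            idx = random.randrange(self.seen)  # keep with prob capacity/seen
            if idx < self.capacity:
                self.data[idx] = item

    def sample(self, n):
        return random.sample(self.data, min(n, len(self.data)))

R = ReservoirBuffer(capacity=200)
# At each meta-update: b comes from the current task, b_m mixes old and new.
b = [("x_new", "y_new")] * 10   # placeholder incoming batch
for item in b:
    R.push(item)
b_m = b + R.sample(10)          # meta-batch over old + new tasks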

4.2 La-MAML

Despite the fact that meta-learning incentivises the alignment of within-task and across-task gradients, there can still be some interference between the gradients of old and new tasks, $g_{old}$ and $g_{new}$ respectively. This would lead to forgetting on $\tau_{1:t-1}$, since its data is no longer fully available to us. This is especially true at the beginning of training a new task, when its gradients aren’t necessarily aligned with the old ones. A mechanism is thus needed to ensure that meta-updates are conservative with respect to $\tau_{1:t-1}$, so as to avoid negative transfer on them. The magnitude and direction of the meta-update needs to be regulated, guided by how the loss on $\tau_{1:t-1}$ would be affected by the update.

In La-MAML, we include a set of learnable per-parameter learning rates (LRs) $\alpha$ to be used in the inner updates, as depicted in Figure 1. This is motivated by our observation that the expression for the gradient of Eq. 5 with respect to the inner loop’s LRs directly reflects the alignment between the old and new tasks. The augmented learning objective and its gradient with respect to the LR vector $\alpha^j$, denoted as $g_{MAML}(\alpha^j)$, are then given as:

$$\min_{\theta_0^j,\, \alpha^j} \; \mathbb{E}_{S_t^j}\Big[ L_{1:t}\big( U_k(\alpha^j;\, \theta_0^j,\, S_t^j) \big) \Big], \qquad g_{MAML}(\alpha^j) = \frac{\partial}{\partial \alpha^j}\, L_{1:t}\big( \theta_k^j \big) \tag{7}$$

We provide the full derivation in Appendix A, and simply state the expression for a first-order approximation finn2017model of $g_{MAML}(\alpha^j)$ here:

$$g_{MAML}(\alpha^j) \;\approx\; -\,\nabla_{\theta_k^j} L_{1:t}(\theta_k^j) \cdot \sum_{k'=0}^{k-1} \nabla_{\theta_{k'}^j} \ell_t(\theta_{k'}^j)$$

The first term corresponds to the gradient of the meta-loss on batch $b_m$. The second term indicates the cumulative gradient from the inner-updates. This expression indicates that the gradient of the LRs will be negative when the inner product between the two terms is high, i.e. the two are aligned; zero when the two are orthogonal (not interfering); and positive when there is interference between the two. Negative (positive) LR gradients would pull up (down) the LR magnitude. We depict this visually in Figure 2.
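Under this first-order view, the per-parameter LR gradient is an elementwise product of the two gradient terms. A small illustrative sketch of ours, assuming the gradients have already been computed as flat vectors:

import numpy as np

def lr_gradient_first_order(g_meta, inner_grads):
    # First-order approximation of the LR gradient: negative elementwise
    # product of the meta-gradient and the cumulative inner-update gradient.
    # Negative entries (alignment) push the corresponding LR up;
    # positive entries (interference) push it down.
    g_traj = np.sum(inner_grads, axis=0)  # cumulative gradient over k inner steps
    return -(g_meta * g_traj)

g_meta = np.array([0.5, -0.2, 0.1])                  # meta-loss gradient on b_m
inner_grads = np.array([[0.4, 0.3, 0.0],
                        [0.6, 0.1, -0.1]])           # k = 2 inner-step gradients
print(lr_gradient_first_order(g_meta, inner_grads))  # [-0.5, 0.08, 0.01]: LR up, down, down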

Algorithm 1 La-MAML: Look-ahead MAML

Input: Network weights $\theta_0$, LRs $\alpha^0$, inner objective $\ell$, meta objective $L$, learning rate $\eta$ for $\alpha$
$j \leftarrow 0$; Initialise Replay Buffer $R$
for $t = 1$ to $T$ do
    for $ep = 1$ to numepochs do
        for batch $b$ in $D_t$ do
            $k \leftarrow$ sizeof($b$)
            $b_m \leftarrow b \,\cup\, \text{sample}(R)$
            for $k' = 0$ to $k - 1$ do
                Push $b[k']$ to $R$ with reservoir sampling
                $\theta_{k'+1}^j \leftarrow \theta_{k'}^j - \alpha^j \cdot \nabla_{\theta_{k'}^j} \ell_t(\theta_{k'}^j,\, b[k'])$
            end for
            (a) $\alpha^{j+1} \leftarrow \alpha^j - \eta\, \nabla_{\alpha^j} L(\theta_k^j,\, b_m)$
            (b) $\theta_0^{j+1} \leftarrow \theta_0^j - \max(0,\, \alpha^{j+1}) \cdot \nabla_{\theta_0^j} L(\theta_k^j,\, b_m)$
            $j \leftarrow j + 1$
        end for
    end for
end for

Figure 2: Different scenarios for the alignment of $g_{old}$ (blue dashed line) and $g_{new}$, going from interference (left) to alignment (right). Yellow arrows denote the inner updates. The LR increases (decreases) when gradients align (interfere).

We propose updating the network weights and LRs asynchronously in the meta-update. Let $\alpha^{j+1}$ be the updated LR vector obtained by taking an SGD step with the LR gradient from Eq. 7 at time $j$. We then update the weights as:

$$\theta_0^{j+1} \;\leftarrow\; \theta_0^j - \max(0,\, \alpha^{j+1}) \cdot \nabla_{\theta_0^j} L_{1:t}(\theta_k^j) \tag{8}$$

where $k$ is the number of steps taken in the inner-loop. The LRs are clipped to positive values to avoid ascending the gradient, and also to avoid making interfering parameter-updates, thus mitigating catastrophic forgetting. The meta-objective thus conservatively modulates the pace and direction of learning to achieve quicker learning progress on a new task while facilitating transfer on old tasks. Algorithm 1 illustrates this procedure. Lines (a), (b) are the only difference between C-MAML and La-MAML, with C-MAML using a fixed scalar LR $\beta$ for the meta-update to $\theta_0^j$ instead of $\max(0, \alpha^{j+1})$.
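The asynchronous update of Eqs. 7-8 can be sketched as follows, in a self-contained first-order NumPy form (the argument names are our own):

import numpy as np

def la_maml_meta_update(theta0, alpha, g_meta, inner_grads, eta):
    # Asynchronous meta-update: step the LRs first, then use the
    # clipped, *updated* LRs for the slow update to the weights.
    g_alpha = -(g_meta * np.sum(inner_grads, axis=0))  # first-order LR gradient
    alpha_new = alpha - eta * g_alpha
    # Clipping at zero avoids gradient ascent and suppresses updates to
    # parameters whose gradients interfere across old and new tasks.
    theta0_new = theta0 - np.maximum(0.0, alpha_new) * g_meta
    return theta0_new, alpha_new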

Our meta-learning based algorithm incorporates concepts from both prior-based and replay-based approaches. The LRs modulate the parameter updates in a data-driven manner, guided by the interplay between gradients on the replay samples and the streaming task. However, since LRs evolve with every meta-update, their decay is temporary. This is unlike many prior-based approaches, where penalties on the change in parameters gradually become so high that the network capacity saturates Kirkpatrick3521. Learnable LRs can be modulated to high and low values as tasks arrive, thus being a simpler, more flexible and elegant way to constrain weights. This asynchronous update resembles trust-region optimisation trustregion, since the LRs are evolved in a manner similar to look-ahead search, which adjusts step-sizes based on the loss incurred on adapted parameters. Our LR update is also analogous to the heuristic uncertainty-based LR update in UCB ebrahimi2019uncertainty and BGD bgd, which we compare to in Section 5.3.

4.3 Connections to Other Work

Stochastic Meta-Descent (SMD): When learning over a non-stationary data distribution, using decaying LR schedules is not common. Strictly diminishing LR schedules aim for closer convergence to a fixed minimum of a stationary distribution, which is at odds with the goal of online learning. It is also not possible to manually tune these schedules since the extent of the data distribution is unknown. However, adaptivity in LRs is still highly desired to adapt to the optimisation landscape, accelerate learning and modulate the degree of adaptation to reduce catastrophic forgetting. Our adaptive LRs can be connected to work on meta-descent hypergradient, smd in offline supervised learning (OSL). While several variations of meta-descent exist, the core idea behind them and our approach is gain adaptation. While we adapt the gain based on the correlation between old and new task gradients to make shared progress on all tasks, hypergradient, smd use the correlation between two successive stochastic gradients to converge faster. We rely on the meta-objective’s differentiability with respect to the LRs to obtain LR hypergradients automatically.

Learning LRs in meta-learning: Meta-SGD li2017metasgd proposed learning the LRs in MAML for few-shot learning. Some notable differences between their update and ours exist. They synchronously update the weights and LRs, while our asynchronous update to the LRs serves to carry out a more conservative update to the weights. The intuition for our update stems from the need to mitigate gradient interference and its connection to the transfer-interference trade-off ubiquitous in continual learning. $\alpha$-MAML alphamaml analytically updates the two scalar LRs in the MAML update for more adaptive few-shot learning. Our per-parameter LRs are modulated implicitly through back-propagation, to regulate change in parameters based on their alignment across tasks, providing our model with a more powerful degree of adaptability in the CL domain.

5 Experiments

In this section, we evaluate La-MAML in settings where the model has to learn a set of sequentially streaming classification tasks. Task-agnostic experiments, where the task identity is unknown at training and test-time, are performed on the MNIST benchmarks with a single-headed model. Task-aware experiments, with known task identity, are performed on the CIFAR and TinyImagenet tinyimagenet datasets with a multi-headed model. Similar to riemer2018learning, we use the retained accuracy (RA) metric to compare various approaches. RA is the average accuracy of the model across tasks at the end of training. We also report the backward-transfer and interference (BTI) values, which measure the average change in the accuracy of each task from when it was learnt to the end of the last task. A smaller BTI implies lesser forgetting during training.
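Both metrics can be computed from the matrix of task accuracies recorded after training on each task; a small sketch under our own naming, where acc[i, j] is the accuracy on task j after finishing training on task i:

import numpy as np

def ra_and_bti(acc):
    # acc: (T, T) array; acc[i, j] = accuracy on task j after training task i.
    T = acc.shape[0]
    ra = acc[-1].mean()                      # retained accuracy: final average over tasks
    bti = np.mean([acc[-1, j] - acc[j, j]    # change since each task was first learnt
                   for j in range(T - 1)])   # (the last task has had no time to be forgotten)
    return ra, bti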

Method      Rotations                       Permutations                    Many
            RA              BTI             RA              BTI             RA              BTI
Online      53.38 ± 1.53    -5.44 ± 1.70    55.42 ± 0.65    -13.76 ± 1.19   32.62 ± 0.43    -19.06 ± 0.86
EWC         57.96 ± 1.33    -20.42 ± 1.60   62.32 ± 1.34    -13.32 ± 2.24   33.46 ± 0.46    -17.84 ± 1.15
GEM         67.38 ± 1.75    -18.02 ± 1.99   55.42 ± 1.10    -24.42 ± 1.10   32.14 ± 0.50    -23.52 ± 0.87
MER         77.42 ± 0.78    -5.60 ± 0.70    73.46 ± 0.45    -9.96 ± 0.45    47.40 ± 0.35    -17.78 ± 0.39
C-MAML      77.33 ± 0.29    -7.88 ± 0.05    74.54 ± 0.54    -10.36 ± 0.14   47.29 ± 1.21    -20.86 ± 0.95
Sync        74.07 ± 0.58    -6.66 ± 0.44    70.54 ± 1.54    -14.02 ± 2.14   44.48 ± 0.76    -24.18 ± 0.65
La-MAML     77.42 ± 0.65    -8.64 ± 0.403   74.34 ± 0.67    -7.60 ± 0.51    48.46 ± 0.45    -12.96 ± 0.073

Table 1: RA, BTI and their standard deviation on MNIST benchmarks. Each experiment is run with 5 seeds.

Efficient Lifelong Learning (LLL): Formalized in chaudhry2018agem, the setup of efficient lifelong learning assumes that incoming data for every task has to be processed in only one single pass: once processed, data samples are not accessible anymore unless they were added to a replay memory. We evaluate our algorithm on this challenging (Single-Pass) setup as well as the standard (Multiple-Pass) setup, where ideally offline training-until-convergence is performed for every task, once we have access to the data.

Method      Rotations        Permutations
La-MAML     45.95 ± 0.38     46.13 ± 0.42
MER         218.03 ± 6.44    227.11 ± 12.12

Table 2: Running times (in seconds) for MER and La-MAML on the MNIST benchmarks for one epoch.

5.1 Continual learning benchmarks

First, we carry out experiments on the toy continual learning benchmarks proposed in prior CL works. MNIST Rotations, introduced in lopez2017gradient, comprises tasks to classify MNIST digits rotated by a different common angle in [0, 180] degrees in each task. In MNIST Permutations, tasks are generated by shuffling the image pixels by a fixed random permutation. Unlike Rotations, the input distribution of each task is unrelated here, leading to less positive transfer between tasks. Both MNIST Permutations and MNIST Rotations have 20 tasks with 1000 samples per task. Many Permutations, a more complex version of Permutations, has five times more tasks (100 tasks) and five times less training data (200 images per task). We use the same architecture and experimental settings as in MER riemer2018learning, allowing us to compare directly with their results. We use the cross-entropy loss as the inner and outer objectives during meta-training. Similar to nichol2018first, we see improved performance when evaluating and summing the meta-loss at all steps of the inner updates as opposed to just the last one.
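For reference, the two MNIST task streams can be generated along the following lines (a sketch of ours using NumPy and SciPy; dataset loading and the exact per-task angle schedule are assumptions, not specified here):

import numpy as np
from scipy.ndimage import rotate

rng = np.random.default_rng(0)

def make_rotation_task(images, angle):
    # One MNIST Rotations task: every digit rotated by a common angle.
    return np.stack([rotate(img, angle, reshape=False) for img in images])

def make_permutation_task(images, n_pixels=784):
    # One MNIST Permutations task: a fixed random pixel shuffle shared by the task.
    perm = rng.permutation(n_pixels)
    flat = images.reshape(len(images), -1)
    return flat[:, perm]

angles = rng.uniform(0, 180, size=20)  # 20 tasks, angles in [0, 180] degrees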

We compare our method in the Single-Pass setup against multiple baselines including Online, Independent, EWC Kirkpatrick3521, GEM lopez2017gradient and MER riemer2018learning (detailed in Appendix G), as well as different ablations (discussed in Section 5.3). In Table 1, we see that La-MAML achieves comparable or better performance than the baselines on all benchmarks. Table 2 shows that La-MAML matches the performance of MER in less than 20% of the training time, owing to its sample-efficient objective which allows it to make more learning progress per iteration. This also allows us to scale it to real-world visual recognition problems, as described next.

5.2 Real-world classification

While La-MAML fares well on the MNIST benchmarks, we are interested in understanding its capabilities on more complex visual classification benchmarks. We conduct experiments on the CIFAR-100 dataset in a task-incremental manner lopez2017gradient, where 20 tasks comprising disjoint 5-way classification problems are streamed. We also evaluate on the TinyImagenet-200 dataset by partitioning its 200 classes into 40 5-way classification tasks. Experiments are carried out in both the Single-Pass and Multiple-Pass settings, where in the latter we allow all CL approaches to train up to a maximum of 10 epochs. Each method is allowed a replay-buffer, containing up to 200 and 400 samples for CIFAR-100 and TinyImagenet respectively. We provide further details about the baselines and evaluation setup in Appendix F.

Method      CIFAR-100                                                       TinyImagenet
            Multiple                        Single                          Multiple                        Single
            RA              BTI             RA              BTI             RA              BTI             RA              BTI
IID         85.60 ± 0.40    -               -               -               77.1 ± 1.06     -               -               -
ER          59.70 ± 0.75    -16.50 ± 1.05   47.88 ± 0.73    -12.46 ± 0.83   48.23 ± 1.51    -19.86 ± 0.70   39.38 ± 0.38    -14.33 ± 0.89
iCARL       60.47 ± 1.09    -15.10 ± 1.04   53.55 ± 1.69    -8.03 ± 1.16    54.77 ± 0.32    -3.93 ± 0.55    45.79 ± 1.49    -2.73 ± 0.45
GEM         62.80 ± 0.55    -17.00 ± 0.26   48.27 ± 1.10    -13.7 ± 0.70    50.57 ± 0.61    -20.50 ± 0.10   40.56 ± 0.79    -13.53 ± 0.65
AGEM        58.37 ± 0.13    -17.03 ± 0.72   46.93 ± 0.31    -13.4 ± 1.44    46.38 ± 1.34    -19.96 ± 0.61   38.96 ± 0.47    -13.66 ± 1.73
MER         -               -               51.38 ± 1.05    -12.83 ± 1.44   -               -               44.87 ± 1.43    -12.53 ± 0.58
Meta-BGD    65.09 ± 0.77    -14.83 ± 0.40   57.44 ± 0.95    -10.6 ± 0.45    *               *               50.64 ± 1.98    -6.60 ± 1.73
C-MAML      65.44 ± 0.99    -13.96 ± 0.86   55.57 ± 0.94    -9.49 ± 0.45    61.93 ± 1.55    -11.53 ± 1.11   48.77 ± 1.26    -7.6 ± 0.52
La-ER       67.17 ± 1.14    -12.63 ± 0.60   56.12 ± 0.61    -7.63 ± 0.90    54.76 ± 1.94    -15.43 ± 1.36   44.75 ± 1.96    -10.93 ± 1.32
Sync        67.06 ± 0.62    -13.66 ± 0.50   58.99 ± 1.40    -8.76 ± 0.95    65.40 ± 1.40    -11.93 ± 0.55   52.84 ± 2.55    -7.3 ± 1.93
La-MAML     70.08 ± 0.66    -9.36 ± 0.47    61.18 ± 1.44    -9.00 ± 0.2     66.99 ± 1.65    -9.13 ± 0.90    52.59 ± 1.35    -3.7 ± 1.22

Table 3: Results on the standard continual (Multiple) and LLL (Single) setups with CIFAR-100 and TinyImagenet-200. Experiments are run with 3 seeds. * indicates result omitted due to high instability.

Table 3 reports the results of these experiments. We consistently observe superior performance of La-MAML as compared to other CL baselines on both datasets across setups. While the iCARL baseline attains lower BTI in some setups, it achieves that at the cost of much lower performance throughout learning. Among the high-performing approaches, La-MAML has the lowest BTI. Recent work 2019tinymemories, riemer2018learning noted that Experience Replay (ER) is often a very strong baseline that closely matches the performance of the proposed algorithms. We highlight the fact that meta-learning and LR modulation combined show an improvement of more than 10 and 18% respectively (as the number of tasks increases from CIFAR to TinyImagenet) over the ER baseline in our case, with limited replay. Overall, we see that our method is robust and better-performing under both the standard and LLL setups of CL, which come with different kinds of challenges. Many CL methods ebrahimi2019uncertainty, serra2018hat are suitable for only one of the two setups. Further, as explained in Figure 3, our model evolves to become resistant to forgetting as training progresses. This means that beyond a point, it can keep making gradient updates on a small window of incoming samples without needing to do meta-updates.

5.3 Evaluation of La-MAML’s learning rate modulation

To capture the gains from learning the LRs, we compare La-MAML with our base algorithm, C-MAML. We ablate our choice of updating LRs asynchronously by constructing a version of C-MAML where per-parameter learnable LRs are used in the inner updates while the meta-update still uses a constant scalar LR during training. We refer to it as Sync-La-MAML or Sync since it has synchronously updated LRs that don’t modulate the meta-update. We also construct an ablation referred to as La-ER, where the parameter updates are carried out as in ER but the LRs are modulated using the La-MAML objective’s first-order version. This tells us what the gains of LR modulation are over ER, since there is no meta-learning to encourage gradient alignment of the model parameters. While only minor gains are seen on the MNIST benchmarks from asynchronous LR modulation, the performance gap increases as the tasks get harder. On CIFAR-100 and TinyImagenet, we see a trend in the RA of our variants with La-MAML performing best followed by Sync. This shows that optimising the LRs aids learning and our asynchronous update helps in knowledge consolidation by enforcing conservative updates to mitigate interference.

To test our LR modulation against an alternative bayesian modulation scheme proposed in BGD bgd, we define a baseline called Meta-BGD, where per-parameter variances are modulated instead of LRs. This is described in further detail in Appendix G. Meta-BGD emerges as a strong baseline and matches the performance of C-MAML given enough Monte Carlo iterations $m$, implying $m$ times more computation than C-MAML. Additionally, Meta-BGD was found to be sensitive to hyperparameters and required extensive tuning. We present a discussion of the robustness of our approach in Appendix E, as well as a discussion of the setups adopted in prior work in Appendix H.

We also compare the gradient alignment of our three variants along with ER in Table 4, by calculating the cosine similarity between the gradients of the replay samples and newly arriving data samples. As previously stated, the aim of many CL algorithms is to achieve high gradient alignment across tasks to allow parameter-sharing between them. We see that our variants achieve an order of magnitude higher cosine similarity compared to ER, verifying that our objective promotes gradient alignment.
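The alignment statistic reported in Table 4 can be computed as below (a sketch of ours; the two gradients are assumed to be flattened over all model parameters):

import numpy as np

def cosine_alignment(g_replay, g_new, eps=1e-12):
    # Cosine similarity between the gradient on replay samples and
    # the gradient on the newly arriving batch; values lie in [-1, 1].
    num = float(np.dot(g_replay, g_new))
    den = float(np.linalg.norm(g_replay) * np.linalg.norm(g_new)) + eps
    return num / den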

Figure 3: Retained Accuracy (RA) for La-MAML, plotted every 25 meta-updates up to Task 5 on CIFAR-100. RA at iteration j (with j increasing along the x-axis) denotes accuracy on all tasks seen until then. Red denotes the RA computed during the inner updates (at $\theta_k^j$). Blue denotes RA computed at $\theta_0^{j+1}$, right after a meta-update. We see that in the beginning, inner updates lead to catastrophic forgetting (CF) since the weights are not suitable for CL yet, but they eventually become resistant when trained to retain old knowledge while learning on a stream of correlated data. We also see that RA maintains its value even as more tasks are added, indicating that the model is successful at learning new tasks without sacrificing performance on old ones.
Dataset ER C-MAML SYNC La-MAML
CIFAR-100 0.0017 0.0003 0.0004 0.0027
TinyImagenet 0.0005 0.0005 0.0020 0.0023
Table 4: Gradient alignment on the CIFAR-100 and TinyImagenet datasets (values lie in [-1, 1]; higher is better)

6 Conclusion

We introduced La-MAML, an efficient meta-learning algorithm that leverages replay to avoid forgetting and favor positive backward transfer by learning the weights and LRs in an asynchronous manner. It is capable of learning online on a non-stationary stream of data and scales to vision tasks. We presented results that showed better performance against the state-of-the-art in the setup of efficient lifelong learning (LLL) chaudhry2018agem, as well as the standard continual learning setting. In the future, more work on analysing and producing good optimizers for CL is needed, since many of our standard go-to optimizers like Adam 2014adam are primarily aimed at ensuring faster convergence in stationary supervised learning setups. Another interesting direction is to explore how the connections to meta-descent can lead to more stable training procedures for meta-learning that can automatically adjust hyper-parameters on-the-fly based on training dynamics.

The authors are grateful to Matt Riemer, Sharath Chandra Raparthy, Alexander Zimin, Heethesh Vhavle and the anonymous reviewers for proof-reading the paper and suggesting improvements. This research was enabled in part by support provided by Compute Canada (www.computecanada.ca).

References

Appendix A Hypergradient Derivation for La-MAML

We derive the gradient of the weights $\theta_0^j$ and LRs $\alpha^j$ at time-step $j$ under the $k$-step MAML objective, with $L_{1:t}$ as the meta-loss and $\ell_t$ as the inner-objective:

$$\begin{aligned} g_{MAML}(\alpha^j) &= \frac{\partial}{\partial \alpha^j}\, L_{1:t}(\theta_k^j) \;=\; \nabla_{\theta_k^j} L_{1:t}(\theta_k^j) \cdot \frac{\partial \theta_k^j}{\partial \alpha^j} \\ \frac{\partial \theta_k^j}{\partial \alpha^j} &= \frac{\partial \theta_{k-1}^j}{\partial \alpha^j} - \nabla_{\theta_{k-1}^j} \ell_t(\theta_{k-1}^j) - \alpha^j\, \nabla^2_{\theta_{k-1}^j} \ell_t(\theta_{k-1}^j)\, \frac{\partial \theta_{k-1}^j}{\partial \alpha^j} \\ &\overset{(a)}{=} -\sum_{k'=0}^{k-1} \left[ \prod_{m=k'+1}^{k-1} \Big( I - \alpha^j \nabla^2 \ell_t(\theta_m^j) \Big) \right] \nabla_{\theta_{k'}^j} \ell_t(\theta_{k'}^j) \;+\; \left[ \prod_{m=0}^{k-1} \Big( I - \alpha^j \nabla^2 \ell_t(\theta_m^j) \Big) \right] \frac{\partial \theta_0^j}{\partial \alpha^j} \\ &\overset{(b)}{=} -\sum_{k'=0}^{k-1} \left[ \prod_{m=k'+1}^{k-1} \Big( I - \alpha^j \nabla^2 \ell_t(\theta_m^j) \Big) \right] \nabla_{\theta_{k'}^j} \ell_t(\theta_{k'}^j) \end{aligned}$$

where (a) is obtained by recursively expanding and differentiating the update function as done in the step before it, and (b) is obtained by assuming that the initial weight in the meta-update at time $j$, $\theta_0^j$, is constant with respect to $\alpha^j$.

Similarly, we can derive the MAML gradient for the weights $\theta_0^j$, denoted as $g_{MAML}(\theta_0^j)$, as:

$$\begin{aligned} g_{MAML}(\theta_0^j) &= \nabla_{\theta_0^j} L_{1:t}(\theta_k^j) \quad \text{(repeatedly applying the chain rule and using } \theta_{k'}^j = U(\theta_{k'-1}^j)\text{)} \\ &= L_{1:t}'(\theta_k^j)\, \big( I - \alpha\, \ell_t''(\theta_{k-1}^j) \big) \cdots \big( I - \alpha\, \ell_t''(\theta_0^j) \big) \quad \text{(using } U'(\theta_{k'}^j) = I - \alpha\, \ell_t''(\theta_{k'}^j)\text{; } {}' \text{ implies derivative with respect to the argument)} \\ &= \left( \prod_{k'=0}^{k-1} \big( I - \alpha\, \ell_t''(\theta_{k'}^j) \big) \right) \nabla_{\theta_k^j} L_{1:t}(\theta_k^j) \end{aligned}$$

Setting all first-order gradient terms as constants to ignore second-order derivatives, we get the first-order approximation as:

$$g_{FOMAML}(\theta_0^j) = \nabla_{\theta_k^j} L_{1:t}(\theta_k^j)$$

In Appendix B, we show the equivalence of the C-MAML and CL objectives in Eq. 6 by showing that the gradient of the former ($g_{MAML}$) is equivalent to the gradient of the latter.

Appendix B Equivalence of Objectives

It is straightforward to show that when we optimise the OML objective through the $k$-step MAML update, as proposed in C-MAML in Eq. 5:

$$\min_{\theta_0^j} \; \mathbb{E}_{S_t^j}\Big[ L_{1:t}\big( U_k(\theta_0^j,\, S_t^j) \big) \Big] \tag{9}$$

where the inner-updates are taken using data from the streaming task $\tau_t$, and the meta-loss is computed on the data from all tasks seen so far, it will correspond to minimising the following surrogate loss used in CL:

$$\min_{\theta} \; \mathbb{E}_{\tau_{1:t}}\!\left[ \sum_{p=1}^{t} \ell_p(\theta) \;-\; \alpha\, \frac{\partial \ell_t(\theta)}{\partial \theta} \cdot \sum_{p=1}^{t} \frac{\partial \ell_p(\theta)}{\partial \theta} \right] \tag{10}$$

We show the equivalence for the case when $k = 1$; for higher $k$ the form gets more complicated but essentially has a similar set of terms. Reptile nichol2018first showed that the $k$-step MAML gradient for the weights at time $j$, denoted as $g_{MAML}$, is of the form:

$$g_{MAML} = \bar{g}_2 \;-\; \alpha \bar{H}_2 \bar{g}_1 \;-\; \alpha \bar{H}_1 \bar{g}_2 \;+\; O(\alpha^2)$$

where:

$\bar{H}_2$ : Hessian of the meta-loss evaluated at the initial point
$\bar{H}_1$ : Hessian of the inner-objective evaluated at the initial point

and $\bar{g}_1$, $\bar{g}_2$ denote the gradients of the inner-objective and the meta-loss evaluated at the initial point. And, in our case, the inner-objective is $\ell_t$ and the meta-loss is $L_{1:t}$; since $\mathbb{E}\big[ \bar{H}_2 \bar{g}_1 + \bar{H}_1 \bar{g}_2 \big] = \frac{\partial}{\partial \theta}\, \mathbb{E}\big[ \bar{g}_1 \cdot \bar{g}_2 \big]$, taking the expected gradient recovers the gradient-alignment term in Eq. 10.

Bias in the objective: We can see in Eq. 10 that the gradient alignment term introduces some bias, which means that the parameters don’t exactly converge to the minimiser of the losses on all tasks. This has been acceptable in the CL regime since we don’t aim to reach the minimiser of some stationary distribution anyway (as also mentioned in Section 4.3). If we did converge to the minimiser of, say, tasks $\tau_{1:t}$ at some time $j$, this minimiser would no longer be optimal as soon as we see the new task $\tau_{t+1}$. Therefore, in the limit of infinite tasks and time, ensuring low interference between tasks will pay off much more as opposed to being able to converge to the exact minima, by allowing us to make shared progress on both previous and incoming tasks.

Appendix C C-MAML Algorithm

Algorithm 2 outlines the training procedure for the C-MAML algorithm we propose. (Our algorithm, Continual-MAML, is different from a concurrent work (https://arxiv.org/abs/2003.05856) which proposes an algorithm with the same name.)

Algorithm 2 C-MAML

Input: Network weights $\theta_0$, inner objective $\ell$, meta objective $L$, inner learning rate $\alpha$, outer learning rate $\beta$
Initialise replay-buffer $R$
for $t = 1$ to $T$ do
    for $ep = 1$ to numepochs do
        for batch $b$ in $D_t$ do
            $k \leftarrow$ sizeof($b$)
            $b_m \leftarrow b \,\cup\, \text{sample}(R)$  ▷ batch of samples from $R$ and $b$ for meta-loss
            for $k' = 0$ to $k - 1$ do
                Push $b[k']$ to $R$ with some probability based on reservoir sampling
                $\theta_{k'+1} \leftarrow \theta_{k'} - \alpha\, \nabla_{\theta_{k'}} \ell_t(\theta_{k'},\, b[k'])$  ▷ inner-update on each incoming sample
            end for
            $\theta_0 \leftarrow \theta_0 - \beta\, \nabla_{\theta_0} L(\theta_k,\, b_m)$  ▷ outer-update by differentiating meta-loss
        end for
    end for
end for

Appendix D Inter-Task Alignment

We assume that at time $j$ during training, we are seeing samples from the streaming task $\tau_t$. It is intuitive to realise that incentivising the alignment of all $\tau_{1:t-1}$ with the current $\tau_t$ indirectly also incentivises the alignment amongst $\tau_{1:t-1}$ as well. To demonstrate this, we compute the mean dot product of the gradients amongst the old tasks $\tau_{1:t-1}$ as the new task $\tau_t$ is added, for $t$ varying from 2 to 11. We do this for C-MAML and La-MAML on CIFAR-100.

As can be seen in Figures 4(a) and 4(b), the alignment stays positive and roughly constant even as more tasks are added.

(a) C-MAML
(b) La-MAML
Figure 4: Average dot product amongst gradients of $\tau_{1:t-1}$ as new tasks are added, for the C-MAML and La-MAML algorithms, calculated over 5 runs. The x-axis shows the streaming task ID, and the y-axis shows the cosine similarity.

Appendix E Robustness

Learning rate is one of the most crucial hyper-parameters during training and it often has to be tuned extensively for each experiment. In this section we analyse the robustness of our proposed variants to their LR-related hyper-parameters on the CIFAR-100 dataset. Our three variants have different sets of these hyper-parameters which are specified as follows:

  • C-MAML: Inner and outer update LR (scalar) for the weights ($\alpha$ and $\beta$)

  • Sync La-MAML: Inner loop initialization value for the vector LRs ($\alpha_0$), scalar learning rate of the LRs ($\eta$) and scalar learning rate for the weights in the outer update ($\beta$)

  • La-MAML: Scalar initialization value for the vector LRs ($\alpha_0$) and a scalar learning rate of the LRs ($\eta$)

(a) C-MAML: Modulation of $\alpha$ and $\beta$
(b) Sync: Modulation of $\alpha_0$, $\eta$ and $\beta$
(c) La-MAML: Modulation of $\alpha_0$ and $\eta$
Figure 5: Retained Accuracy vs Learning Rates plot for La-MAML and its variants. Figures are plotted by varying one of the learning rate hyperparameters while keeping the others fixed at their optimal values. Each hyperparameter is varied in the range [0.001, 0.3].

La-MAML is considerably more robust to tuning compared to its variants, as can be seen in Figure 5(c). We empirically observe that it only requires tuning of the initial value $\alpha_0$ of the LRs, while being relatively insensitive to the learning rate of the LRs ($\eta$). We see a consistent trend where an increase in $\eta$ leads to an increase in the final accuracy of the model. This increase is very gradual: across a wide range of LRs varying over two orders of magnitude (from 0.003 to 0.3), the difference in RA is only 6%. This means that even without tuning this parameter ($\eta$), La-MAML would have outperformed most baselines at their optimally tuned values.

As seen in Figure 5(a), C-MAML sees considerable performance variation with the tweaking of both the inner and outer LR. We also see that the effects of the variations of the inner and outer LR follow very similar trends, and the optimal values finally selected for them are also identical. This means that we could potentially tune them by doing just a 1D search over them together instead of varying both independently through a 2D grid search. The Sync version of La-MAML (Figure 5(b)), while being relatively insensitive to the scalar initial value $\alpha_0$ and to $\eta$, sees considerable performance variation as the outer learning rate for the weights, $\beta$, is varied. This variant has the most hyper-parameters and only exists for the purpose of ablation.

Fig. 6 shows the result of 2D grid-searches over sets of the above-mentioned hyper-parameters for C-MAML and La-MAML for a better overview.

(a) C-MAML: Modulation of $\alpha$ and $\beta$
(b) La-MAML: Modulation of $\alpha_0$ and $\eta$
Figure 6: Plots of Retained Accuracy (RA) across hyper-parameter variation for C-MAML and La-MAML. We show results of the grid search over the learning rate hyperparameters. RA decreases from red to blue. All the hyperparameters are varied in the range [0.001, 0.3], with the axes being in log-scale.

Appendix F Experimental Details

We carry out hyperparameter tuning for all the approaches by performing a grid-search over the range [0.0001, 0.3] for hyper-parameters related to the learning rate. For the Multiple-Pass setup we use 10 epochs for all the CL approaches. In the Single-Pass setup, all compared approaches have a hyper-parameter called glances, which indicates the number of gradient updates or meta-updates made on each incoming sample of data. In the Single-Pass (LLL) setup, it becomes essential to take multiple gradient steps on each sample (or see each sample for multiple glances), since once we move on to later samples, we can’t revisit old data samples. The performance of the algorithms naturally increases with the number of glances up to a certain point. To find the optimal number of glances to take over each sample, we search over the values [1, 2, 3, 5, 10]. Tables 5 and 6 list the optimal hyperparameters for all the compared approaches. All setups used the SGD optimiser since it was found to perform better than Adam 2014adam (possibly due to reasons stated in Section 4.3 regarding the CL setup).

To avoid exploding gradients, we clip the gradient values of all approaches at a norm of 2.0. Class divisions across different tasks vary with the random seeds with which the experiments were run. Overall, we did not see much variability across different class splits, with the variation being within 0.5-2% of the mean reported result, as can be seen from Table 3.

For all our baselines, we use a constant batch-size of 10 samples from the streaming task. This batch is augmented with 10 samples from the replay buffer for the replay-based approaches. La-MAML and its variants split the batch from the streaming task into a sequence of smaller disjoint sets to take gradient steps in the inner-loop. In MER, each sample from the incoming task is augmented with a batch of 10 replay samples to form the batch used for the meta-update. We found very small performance gaps between the first- and second-order versions of our proposed variants, with performance differences in the range of 1-2% for RA. This is in line with the observation that deep neural networks have near-zero hessians since the ReLU non-linearity is linear almost everywhere zero-hessian.

Method Parameter CIFAR-100 TinyImagenet
Single Multiple Single Multiple
ER LR 0.03 0.03 0.1 0.1
Epochs/Glances 10 10 10 10
IID LR - 0.03 - 0.01
Epochs/Glances - 50 - 50
iCARL LR 0.03 0.03 0.01 0.01
Epochs/Glances 2 10 2 10
GEM LR 0.03 0.03 0.03 0.03
Epochs/Glances 2 10 2 10
AGEM LR 0.03 0.03 0.01 0.01
Epochs/Glances 2 10 2 10
MER LR 0.1 - 0.1 -
LR 0.1 - 0.1 -
LR 1 - 1 -
Epochs/Glances 10 - 10 -
Meta-BGD 50 50 50 -
std-init 0.02 0.02 0.02 -
0.1 0.1 0.1 -
mc-iters 2 2 2 -
Epochs/Glances 3 10 3 -
C-MAML 0.03 0.03 0.03 0.03
0.03 0.03 0.03 0.03
Epochs/Glances 5 10 2 10
La-ER 0.1 0.1 0.03 0.03
0.1 0.1 0.1 0.1
Epochs/Glances 1 10 2 10
Sync La-MAML 0.1 0.1 0.075 0.075
0.1 0.1 0.075 0.075
0.3 0.3 0.25 0.25
Epochs/Glances 5 10 2 10
La-MAML 0.1 0.1 0.1 0.1
0.3 0.3 0.3 0.3
Epochs/Glances 10 10 2 10




Table 5: Final hyperparameters for all compared approaches on the CIFAR and TinyImagenet benchmarks
Method Parameter Permutations Rotations Many
C-MAML 0.03 0.1 0.03
0.1 0.1 0.15
Glances 5 5 5

Sync La-MAML
0.15 0.15 0.03
0.1 0.3 0.03
0.1 0.1 0.1
Glances 5 5 10

La-MAML
0.3 0.3 0.1
0.15 0.15 0.1
Glances 5 5 10

Table 6: Final hyperparameters used for our variants on the MNIST benchmarks

MNIST Benchmarks: On the MNIST continual learning benchmarks, images of size 28x28 are flattened to create a 1x784 array. This array is passed on to a fully-connected neural network having two layers with 100 nodes each. Each layer uses a ReLU non-linearity. The output layer uses a single head with 10 nodes corresponding to the 10 classes. In all our experiments, we use a modest replay buffer of size 200 for MNIST Rotations and Permutations, and size 500 for Many Permutations.

Real-world visual classification: For CIFAR and TinyImagenet we used a CNN having 3 and 4 convolutional layers respectively, with 160 3x3 filters each. The output from the final convolutional layer is flattened and passed through 2 fully connected layers having 320 and 640 units respectively. All layers are followed by a ReLU non-linearity. Finally, a multi-headed output layer is used for performing 5-way classification for every task. For CIFAR and TinyImagenet, we allow a replay buffer of size 200 and 400 respectively, which implies that each class in these datasets gets roughly 1-2 samples in the buffer. For TinyImagenet, we split the validation set into val and test splits, since the labels in the actual test set are not released.
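A sketch of the CIFAR variant of this architecture in PyTorch, under our own reading of the description above (stride and padding choices are assumptions of ours, not specified in the text):

import torch
import torch.nn as nn

class ContinualCNN(nn.Module):
    def __init__(self, n_tasks=20, n_classes_per_task=5, in_channels=3):
        super().__init__()
        # 3 conv layers with 160 3x3 filters each (CIFAR variant), ReLU after each.
        layers, c = [], in_channels
        for _ in range(3):
            layers += [nn.Conv2d(c, 160, kernel_size=3, stride=2, padding=1), nn.ReLU()]
            c = 160
        self.features = nn.Sequential(*layers)
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.LazyLinear(320), nn.ReLU(),   # LazyLinear infers the flattened size
            nn.Linear(320, 640), nn.ReLU(),
        )
        # Multi-headed output: one 5-way head per task.
        self.heads = nn.ModuleList(
            [nn.Linear(640, n_classes_per_task) for _ in range(n_tasks)])

    def forward(self, x, task_id):
        return self.heads[task_id](self.fc(self.features(x)))

model = ContinualCNN()
logits = model(torch.randn(4, 3, 32, 32), task_id=0)  # shape (4, 5)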

Appendix G Baselines

On the MNIST benchmarks, we compare our algorithm against the baselines used in riemer2018learning, which are as follows:

  • Online: A baseline for the LLL setup, where a single network is trained one example at a time with SGD.

  • EWC Kirkpatrick3521: Elastic Weight Consolidation is a regularisation-based method which constrains the weights important for the previous tasks to avoid catastrophic forgetting.

  • GEM lopez2017gradient: Gradient Episodic Memory does constrained optimisation by solving a quadratic program on the gradients of new and replay samples, trying to make sure that these gradients do not alter the past tasks’ knowledge.

  • MER riemer2018learning: Meta Experience Replay samples i.i.d. data from a replay memory to meta-learn model parameters that show increased gradient alignment between old and current samples. We evaluate against this baseline only in the LLL setups.

On the real-world visual classification datasets, we carry out experiments with GEM and MER, along with:

  • IID: The network gets the data from all tasks in an independent and identically distributed manner, thus bypassing the issue of catastrophic forgetting completely. IID therefore acts as an upper bound for the RA achievable with this network.

  • ER: Experience Replay uses a small replay buffer to store old data using reservoir sampling. This stored data is then replayed along with the new data samples.

  • iCARL rebuffi2017icarl: iCARL is originally from the family of class-incremental learners, which learn to classify images in a metric space. It prevents catastrophic forgetting by using a memory of exemplar samples to perform distillation from the old network weights. Since we perform experiments in a task-incremental setting, we use the modified version of iCARL (as used by GEM lopez2017gradient), where the distillation loss is calculated only over the logits of the particular task.

  • A-GEM chaudhry2018agem: Averaged Gradient Episodic Memory proposed to project the gradients of the new task in a direction that avoids interference with the average gradient of the old samples in the buffer.

  • Meta-BGD: Bayesian Gradient Descent bgd proposes training a bayesian neural network for CL, where the learning rates for the parameters (the means) are derived from their variances. We construct this baseline by equipping C-MAML with bayesian training, where each parameter in $\theta$ is now sampled from a gaussian distribution with a certain mean and variance. The inner loop stays the same as in C-MAML (constant LR), but the magnitude of the meta-update to the parameters in $\theta$ is now influenced by their associated variances. The variance updates themselves have a closed-form expression which depends on $m$ Monte Carlo samples of the meta-loss, thus implying $m$ forward passes of the inner and outer loops (each time with a newly sampled $\theta$) to get meta-gradients.

Appendix H Discussion on Prior Work

In Table 7, we provide a comparative overview of various continual learning methods to situate our work better in the context of prior work.

Prior-focused methods face model capacity saturation as the number of tasks increases. These methods freeze weights to defy forgetting, and so penalise changes to the weights even if those changes could potentially improve model performance on old tasks. They are also not suitable for the LLL setup (Section 5), since it requires many passes through the data for every task to learn weights that are optimal enough to be frozen. Additionally, the success of weight-freezing schemes can be attributed to over-parameterisation in neural networks, leading to sub-networks with sufficient capacity to learn separate tasks. However, continual-learning setups are often motivated by resource-constrained settings requiring efficiency and scalability, so solutions that allow light-weight continual learners are desirable. Meta-learning algorithms are able to exploit even small models to learn a good initialization where gradients are aligned across tasks, enabling shared progress on the optimisation of task-wise objectives. Our method additionally allows meta-learning to achieve a prior-focusing effect through the asynchronous meta-update, without necessarily needing over-parameterised models.

In terms of resources, meta-learning based methods require smaller replay memories than traditional methods because they learn to generalise better across and within tasks, thus being sample-efficient. Our learnable learning rates incur a memory overhead equal to the parameters of the network. This is comparable to or less than many prior-based methods, which store between 1 and $T$ scalars per parameter depending on the approach ($T$ is the number of tasks).

It should be noted that our learning rate modulation involves clipping updates for parameters with non-aligning gradients. In this aspect, it is related to methods like GEM and A-GEM mentioned before. The distinction is that our method takes some of the burden off of the clipping by ensuring that gradients are more aligned in the first place. This means that, on the whole, there should be less interference and therefore less clipping of updates deemed essential for learning new tasks.

Approach        FWT  BWT  Saturates  LLL  Storage     Algorithm
Prior-Focused   -                         T Fishers   EWC Kirkpatrick3521
Prior-Focused   -                         T masks     HAT serra2018hat
Prior-Focused   -                         2x params   BGD/UCB bgd, ebrahimi2019uncertainty
Replay          -    -                    replay      iCARL rebuffi2017icarl
Replay          -    -                    replay      GEM lopez2017gradient
Meta + Replay                             replay      MER riemer2018learning
Meta + Replay                             replay      Ours

Table 7: Setups in prior work: We describe the setups and assumptions adopted by prior work, focusing on approaches relevant to our method. FWT and BWT refer to forward and backward transfer as defined in lopez2017gradient. '-' refers to no inductive bias for or against the specific property. Saturation of capacity refers to reduced network plasticity due to weight-change penalties gradually making further learning impossible. The LLL setup is defined in Section 5. The markers under replay indicate whether a method’s replay requirements are lesser or more compared to other methods in the table. 'T Fishers' refers to the Fisher Information Matrix (FIM) computed per task ($T$ is the number of tasks); each FIM has storage equal to that of the model parameters. Approaches using Bayesian Neural Networks require twice as many parameters (as does La-MAML) to store the mean and variance estimates per parameter.