This paper considers metalearning problems, where there is a distribution of tasks, and we would like to obtain an agent that performs well (i.e., learns quickly) when presented with a previously unseen task sampled from this distribution. We present a remarkably simple metalearning algorithm called Reptile, which learns a parameter initialization that can be fine-tuned quickly on a new task. Reptile works by repeatedly sampling a task, training on it, and moving the initialization towards the trained weights on that task. Unlike MAML, which also learns an initialization, Reptile doesn't require differentiating through the optimization process, making it more suitable for optimization problems where many update steps are required. We show that Reptile performs well on some well-established benchmarks for few-shot classification. We provide some theoretical analysis aimed at understanding why Reptile works.
While machine learning systems have surpassed humans at many tasks, they generally need far more data to reach the same level of performance. For example, Schmidt et al. [17, 15] showed that human subjects can recognize new object categories based on a few example images. Lake et al. [12] noted that on the Atari game of Frostbite, human novices were able to make significant progress on the game after 15 minutes, but double-dueling-DQN [19] required more than 1000 times more experience to attain the same score.
It is not completely fair to compare humans to algorithms learning from scratch, since humans enter the task with a large amount of prior knowledge, encoded in their brains and DNA. Rather than learning from scratch, they are fine-tuning and recombining a set of pre-existing skills. The work cited above, by Tenenbaum and collaborators, argues that humans’ fast-learning abilities can be explained as Bayesian inference, and that the key to developing algorithms with human-level learning speed is to make our algorithms more Bayesian. However, in practice, it is challenging to develop (from first principles) Bayesian machine learning algorithms that make use of deep neural networks and are computationally feasible.
Meta-learning has emerged recently as an approach for learning from small amounts of data. Rather than trying to emulate Bayesian inference (which may be computationally intractable), meta-learning seeks to directly optimize a fast-learning algorithm, using a dataset of tasks. Specifically, we assume access to a distribution over tasks, where each task is, for example, a classification problem. From this distribution, we sample a training set and a test set of tasks. Our algorithm is fed the training set, and it must produce an agent that has good average performance on the test set. Since each task corresponds to a learning problem, performing well on a task corresponds to learning quickly.
A variety of different approaches to meta-learning have been proposed, each with its own pros and cons. In one approach, the learning algorithm is encoded in the weights of a recurrent network, but gradient descent is not performed at test time. This approach was proposed by Hochreiter et al. [8] who used LSTMs for next-step prediction and has been followed up by a burst of recent work, for example, Santoro et al. [16] on few-shot classification, and Duan et al. [3] for the POMDP setting.
A second approach is to learn the initialization of a network, which is then fine-tuned at test time on the new task. A classic example of this approach is pretraining using a large dataset (such as ImageNet [2]) and fine-tuning on a smaller dataset (such as a dataset of different species of bird [20]). However, this classic pre-training approach has no guarantee of learning an initialization that is good for fine-tuning, and ad-hoc tricks are required for good performance. More recently, Finn et al. [4] proposed an algorithm called MAML, which directly optimizes performance with respect to this initialization, differentiating through the fine-tuning process. In this approach, the learner falls back on a sensible gradient-based learning algorithm even when it receives out-of-sample data, thus allowing it to generalize better than the RNN-based approaches [5]. On the other hand, since MAML needs to differentiate through the optimization process, it’s not a good match for problems where we need to perform a large number of gradient steps at test time. The authors also proposed a variant called first-order MAML (FOMAML), which is defined by ignoring the second derivative terms, avoiding this problem but at the expense of losing some gradient information. Surprisingly, though, they found that FOMAML worked nearly as well as MAML on the Mini-ImageNet dataset [18]. (This result was foreshadowed by prior work in meta-learning [1, 13] that ignored second derivatives when differentiating through gradient descent, without ill effect.) In this work, we expand on that insight and explore the potential of meta-learning algorithms based on first-order gradient information, motivated by the potential applicability to problems where it’s too cumbersome to apply techniques that rely on higher-order gradients (like full MAML).
We make the following contributions:
We point out that first-order MAML [4] is simpler to implement than was widely recognized prior to this article.
We introduce Reptile, an algorithm closely related to FOMAML, which is equally simple to implement. Reptile is so similar to joint training (i.e., training to minimize loss on the expectation over training tasks) that it is especially surprising that it works as a meta-learning algorithm. Unlike FOMAML, Reptile doesn’t need a training-test split for each task, which may make it a more natural choice in certain settings. It is also related to the older idea of fast weights / slow weights [7].
We provide a theoretical analysis that applies to both first-order MAML and Reptile, showing that they both optimize for within-task generalization.
We consider the optimization problem of MAML [4]: find an initial set of parameters, $\phi$, such that for a randomly sampled task $\tau$ with corresponding loss $L_\tau$, the learner will have low loss after $k$ updates. That is:
$$\min_{\phi}\; \mathbb{E}_{\tau}\left[L_\tau\left(U^k_\tau(\phi)\right)\right] \tag{1}$$
where $U^k_\tau$ is the operator that updates $\phi$ $k$ times using data sampled from $\tau$. In few-shot learning, $U$ corresponds to performing gradient descent or Adam [10] on batches of data sampled from $\tau$.
MAML solves a version of Equation 1 that makes one additional assumption: for a given task $\tau$, the inner-loop optimization uses training samples $A$, whereas the loss is computed using test samples $B$. This way, MAML optimizes for generalization, akin to cross-validation. Omitting the superscript $k$, we notate this as
$$\min_{\phi}\; \mathbb{E}_{\tau}\left[L_{\tau,B}\left(U_{\tau,A}(\phi)\right)\right] \tag{2}$$
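MAML works by optimizing this loss through stochastic gradient descent, i.e., computing
$$g_{\text{MAML}} = \frac{\partial}{\partial \phi} L_{\tau,B}\left(U_{\tau,A}(\phi)\right) \tag{3}$$
$$= U'_{\tau,A}(\phi)\, L'_{\tau,B}(\tilde\phi), \quad \text{where } \tilde\phi = U_{\tau,A}(\phi) \tag{4}$$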
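In Equation 4, $U'_{\tau,A}(\phi)$ is the Jacobian matrix of the update operation $U_{\tau,A}$. $U_{\tau,A}$ corresponds to adding a sequence of gradient vectors to the initial vector, i.e., $U_{\tau,A}(\phi) = \phi + g_1 + g_2 + \dots + g_k$. (In Adam, the gradients are also rescaled elementwise, but that does not change the conclusions.) First-order MAML (FOMAML) treats these gradients as constants; thus, it replaces the Jacobian $U'_{\tau,A}(\phi)$ by the identity operation. Hence, the gradient used by FOMAML in the outer-loop optimization is $g_{\text{FOMAML}} = L'_{\tau,B}(\tilde\phi)$. Therefore, FOMAML can be implemented in a particularly simple way: (1) sample task $\tau$; (2) apply the update operator, yielding $\tilde\phi = U_{\tau,A}(\phi)$; (3) compute the gradient at $\tilde\phi$, $g_{\text{FOMAML}} = L'_{\tau,B}(\tilde\phi)$; and finally (4) plug $g_{\text{FOMAML}}$ into the outer-loop optimizer.
In this section, we describe a new first-order gradient-based meta-learning algorithm called Reptile. Like MAML, Reptile learns an initialization for the parameters of a neural network model, such that when we optimize these parameters at test time, learning is fast, i.e., the model generalizes from a small number of examples from the test task. The Reptile algorithm is as follows:
Initialize $\phi$, the vector of initial parameters
for iteration = 1, 2, ... do
Sample task $\tau$, corresponding to loss $L_\tau$ on weight vectors $\tilde\phi$
Compute $\tilde\phi = U^k_\tau(\phi)$, denoting $k$ steps of SGD or Adam
Update $\phi \leftarrow \phi + \epsilon(\tilde\phi - \phi)$
end for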
In the last step, instead of simply updating $\phi$ in the direction $\tilde\phi - \phi$, we can treat $(\phi - \tilde\phi)$ as a gradient and plug it into an adaptive algorithm such as Adam [10]. (Actually, as we will discuss in Section 5.1, it is most natural to define the Reptile gradient as $(\phi - \tilde\phi)/\alpha$, where $\alpha$ is the stepsize used by the SGD operation.) We can also define a parallel or batch version of the algorithm that evaluates on $n$ tasks each iteration and updates the initialization to
$$\phi \leftarrow \phi + \epsilon\, \frac{1}{n} \sum_{i=1}^{n} \left(\tilde\phi_i - \phi\right) \tag{5}$$
where $\tilde\phi_i = U^k_{\tau_i}(\phi)$ are the updated parameters on the $i$-th task.
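The serial and batched updates can be sketched in a few lines of Python. This is a minimal illustration, not the authors' implementation: the task distribution (quadratic losses $\tfrac12\lVert\phi - c\rVert^2$ with centers $c$ drawn uniformly from $[-1, 1]^2$), the helper names `sample_task` and `sgd_steps`, and the values of $k$, $\alpha$ and $\epsilon$ are all hypothetical, and plain SGD stands in for the inner optimizer $U^k_\tau$.

```python
import random


def sgd_steps(phi, grad_fn, k, alpha):
    """Apply k plain SGD steps to a parameter list (a stand-in for U^k_tau)."""
    phi = list(phi)
    for _ in range(k):
        g = grad_fn(phi)
        phi = [p - alpha * gi for p, gi in zip(phi, g)]
    return phi


def sample_task(rng):
    """Hypothetical toy task: loss 0.5 * ||phi - c||^2 with a random center c."""
    c = [rng.uniform(-1.0, 1.0), rng.uniform(-1.0, 1.0)]
    return lambda phi: [p - ci for p, ci in zip(phi, c)]  # gradient of the loss


def reptile(phi, n_iters, k=5, alpha=0.1, eps=0.1, n_tasks=1, seed=0):
    """Serial Reptile (n_tasks=1) or the batched update of Equation 5 (n_tasks>1)."""
    rng = random.Random(seed)
    for _ in range(n_iters):
        deltas = []
        for _ in range(n_tasks):
            grad_fn = sample_task(rng)
            phi_tilde = sgd_steps(phi, grad_fn, k, alpha)  # phi~ = U^k_tau(phi)
            deltas.append([pt - p for pt, p in zip(phi_tilde, phi)])
        # phi <- phi + eps * (1/n) * sum_i (phi~_i - phi)
        phi = [p + eps * sum(d[j] for d in deltas) / n_tasks
               for j, p in enumerate(phi)]
    return phi


print(reptile([3.0, -2.0], n_iters=2000))
print(reptile([3.0, -2.0], n_iters=2000, n_tasks=5))
```

For these toy tasks every Hessian is the identity, so the Reptile fixed point is simply the mean center (the origin), which coincides with the joint-training solution; the example only shows the mechanics of the update, not the curvature-dependent behavior analyzed in Section 5.1.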
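This algorithm looks remarkably similar to joint training on the expected loss $\mathbb{E}_\tau[L_\tau]$. Indeed, if we define $U$ to be a single step of gradient descent ($k = 1$), then this algorithm corresponds to stochastic gradient descent on the expected loss:
$$g_{\text{Reptile},k=1} = \mathbb{E}_\tau\left[\phi - U_\tau(\phi)\right]/\alpha \tag{6}$$
$$= \mathbb{E}_\tau\left[\nabla_\phi L_\tau(\phi)\right] \tag{7}$$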
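As a quick numerical check of Equations 6 and 7, the sketch below uses a hypothetical family of three equally likely 1-D tasks with loss $L_c(\phi) = (\phi - c)^4/4$. With $k = 1$, the task-averaged Reptile gradient $(\phi - U_\tau(\phi))/\alpha$ matches the expected-loss gradient up to rounding; with $k = 3$ (divided by $k$ to put it on the same scale, which is our own normalization choice) it does not.

```python
def grad(phi, c):
    """Gradient of the hypothetical task loss L_c(phi) = (phi - c)**4 / 4."""
    return (phi - c) ** 3


def reptile_grad(phi, c, k, alpha):
    """(phi - U^k(phi)) / alpha, using k plain gradient-descent steps on L_c."""
    p = phi
    for _ in range(k):
        p -= alpha * grad(p, c)
    return (phi - p) / alpha


def mean(xs):
    return sum(xs) / len(xs)


phi, alpha = 0.7, 0.05
centers = [-1.0, 0.5, 2.0]  # three equally likely tasks

joint = mean([grad(phi, c) for c in centers])  # E_tau[grad L_tau(phi)]
k1 = mean([reptile_grad(phi, c, 1, alpha) for c in centers])
k3 = mean([reptile_grad(phi, c, 3, alpha) for c in centers]) / 3  # per-step average

print(joint, k1, k3)
```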
However, if we perform multiple gradient updates in the partial minimization ($k > 1$), then the expected update $\mathbb{E}_\tau[U^k_\tau(\phi) - \phi]$ does not correspond to taking a gradient step on the expected loss $\mathbb{E}_\tau[L_\tau]$. Instead, the update includes important terms coming from second- and higher-order derivatives of $L_\tau$, as we will analyze in Section 5.1. Hence, Reptile converges to a solution that’s very different from the minimizer of the expected loss $\mathbb{E}_\tau[L_\tau]$.
Other than the stepsize parameter $\epsilon$ and task sampling, the batched version of Reptile is the same as the SimuParallelSGD algorithm [21]. SimuParallelSGD is a method for communication-efficient distributed optimization, where workers perform gradient updates locally and infrequently average their parameters, rather than the standard approach of averaging gradients.
As a simple case study, let’s consider the 1D sine wave regression problem, which is slightly modified from Finn et al. [4]. This problem is instructive since by design, joint training can’t learn a very useful initialization; however, meta-learning methods can.
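The task is defined by the amplitude $a$ and phase $b$ of a sine wave function $f_\tau(x) = a \sin(x + b)$. The task distribution is defined by sampling $a \sim U([0.1, 5.0])$ and $b \sim U([0, 2\pi])$.
Sample $p$ points $x_1, x_2, \dots, x_p \sim U([-5, 5])$
Learner sees $(x_1, y_1), (x_2, y_2), \dots, (x_p, y_p)$ and predicts the whole function $f(x)$
Loss is $\ell_2$ error on the whole interval $[-5, 5]$:
$$L_\tau(f) = \int_{-5}^{5} dx\, \lVert f(x) - f_\tau(x) \rVert^2 \tag{8}$$
We calculate this integral using equally-spaced points $x$.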
First note that the average function is zero everywhere, i.e., $\mathbb{E}_\tau[f_\tau(x)] = 0$, due to the random phase $b$. Therefore, it is useless to train on the expected loss $\mathbb{E}_\tau[L_\tau]$, as this loss is minimized by the zero function $f(x) = 0$.
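The task distribution and loss can be sketched as follows. The amplitude range $[0.1, 5.0]$, the phase range $[0, 2\pi]$, the interval $[-5, 5]$ and the 50-point grid are assumptions made for this illustration. The sketch checks the two claims above numerically: the task-averaged function is close to zero, and the zero function has a lower expected loss than a nonzero constant.

```python
import math
import random


def sample_sine_task(rng):
    """Sample f_tau(x) = a*sin(x + b); the ranges below are assumptions."""
    a = rng.uniform(0.1, 5.0)          # amplitude
    b = rng.uniform(0.0, 2 * math.pi)  # phase, uniform over a full period
    return lambda x: a * math.sin(x + b)


def grid(n=50, lo=-5.0, hi=5.0):
    """n equally spaced points on [lo, hi] (n=50 is an assumption)."""
    return [lo + (hi - lo) * i / (n - 1) for i in range(n)]


def l2_loss(f, f_tau, xs):
    """Riemann-sum approximation of Equation 8 on an equally spaced grid."""
    dx = (xs[-1] - xs[0]) / (len(xs) - 1)
    return sum((f(x) - f_tau(x)) ** 2 for x in xs) * dx


def expected_loss(f, tasks, xs):
    """Monte Carlo estimate of E_tau[L_tau(f)] over a list of sampled tasks."""
    return sum(l2_loss(f, t, xs) for t in tasks) / len(tasks)


rng = random.Random(0)
tasks = [sample_sine_task(rng) for _ in range(20000)]
xs = grid()


def zero(x):
    return 0.0


# The average function is ~0 at every x because the phase is uniform.
avg_at_1 = sum(f(1.0) for f in tasks) / len(tasks)
print(avg_at_1, expected_loss(zero, tasks[:2000], xs))
```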
On the other hand, MAML and Reptile give us an initialization that outputs approximately $f(x) = 0$ before training on a task $\tau$, but the internal feature representations of the network are such that after training on the sampled datapoints $(x_1, y_1), \dots, (x_p, y_p)$, it closely approximates the target function $f_\tau$. Figure 1 illustrates this learning progress: after Reptile training, the network can quickly converge to a sampled sine wave and infer the values away from the sampled points. As points of comparison, we also show the behaviors of MAML and a randomly-initialized network on the same task.
In this section, we provide two alternative explanations of why Reptile works.
Here, we will use a Taylor series expansion to approximate the update performed by Reptile and MAML. We will show that both algorithms contain the same leading-order terms: the first term minimizes the expected loss (joint training), the second and more interesting term maximizes within-task generalization. Specifically, it maximizes the inner product between the gradients on different minibatches from the same task. If gradients from different batches have positive inner product, then taking a gradient step on one batch improves performance on the other batch.
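Unlike in the discussion and analysis of MAML, we won’t consider a training set and test set from each task; instead, we’ll just assume that each task gives us a sequence of $k$ loss functions $L_1, L_2, \dots, L_k$; for example, classification loss on different minibatches. We will use the following definitions:
$$g_i = L'_i(\phi_i) \quad \text{(gradient obtained during SGD)} \tag{9}$$
$$\phi_{i+1} = \phi_i - \alpha g_i \quad \text{(sequence of parameter vectors)} \tag{10}$$
$$\bar g_i = L'_i(\phi_1) \quad \text{(gradient at initial point)} \tag{11}$$
$$\bar H_i = L''_i(\phi_1) \quad \text{(Hessian at initial point)} \tag{12}$$
For each of these definitions, $i \in [1, k]$.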
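First, let’s calculate the SGD gradients to $O(\alpha^2)$ as follows.
$$g_i = L'_i(\phi_i) = L'_i(\phi_1) + L''_i(\phi_1)(\phi_i - \phi_1) + O(\lVert \phi_i - \phi_1 \rVert^2) \tag{13}$$
$$= \bar g_i + \bar H_i(\phi_i - \phi_1) + O(\alpha^2) \quad \text{(using } \phi_i - \phi_1 = O(\alpha)\text{)} \tag{14}$$
$$= \bar g_i - \alpha \bar H_i \sum_{j=1}^{i-1} g_j + O(\alpha^2) \quad \text{(using } \phi_i - \phi_1 = -\alpha \textstyle\sum_{j=1}^{i-1} g_j\text{)} \tag{15}$$
$$= \bar g_i - \alpha \bar H_i \sum_{j=1}^{i-1} \bar g_j + O(\alpha^2) \quad \text{(using } g_j = \bar g_j + O(\alpha)\text{)} \tag{16}$$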
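Equation 16 can be checked numerically. The sketch below is a hypothetical 1-D example with three minibatch losses $\tfrac12 h(\phi - c)^2 + \tfrac14 q(\phi - c)^4$, chosen only so that the Hessian varies with $\phi$. It compares the exact SGD gradients with the expansion and shows that the gap shrinks roughly fourfold when $\alpha$ is halved, as expected for an $O(\alpha^2)$ remainder.

```python
def make_loss(h, q, c):
    """Hypothetical 1-D minibatch loss 0.5*h*(p-c)^2 + 0.25*q*(p-c)^4."""
    grad = lambda p: h * (p - c) + q * (p - c) ** 3
    hess = lambda p: h + 3 * q * (p - c) ** 2
    return grad, hess


LOSSES = [make_loss(1.0, 0.5, 0.3), make_loss(2.0, 0.2, -0.4), make_loss(0.5, 1.0, 0.1)]


def sgd_gradients(phi1, alpha):
    """Exact g_i = L_i'(phi_i) along the SGD path phi_{i+1} = phi_i - alpha*g_i."""
    phi, gs = phi1, []
    for grad, _ in LOSSES:
        g = grad(phi)
        gs.append(g)
        phi -= alpha * g
    return gs


def expanded_gradients(phi1, alpha):
    """Equation 16 without the remainder: gbar_i - alpha * Hbar_i * sum_{j<i} gbar_j."""
    gbar = [grad(phi1) for grad, _ in LOSSES]
    hbar = [hess(phi1) for _, hess in LOSSES]
    return [gbar[i] - alpha * hbar[i] * sum(gbar[:i]) for i in range(len(LOSSES))]


def max_error(alpha, phi1=0.8):
    exact, approx = sgd_gradients(phi1, alpha), expanded_gradients(phi1, alpha)
    return max(abs(e - a) for e, a in zip(exact, approx))


# Halving alpha should shrink the error about 4x if it is O(alpha^2).
print(max_error(0.01) / max_error(0.005))
```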
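Next, we will approximate the MAML gradient. Define $U_i$ as the operator that updates the parameter vector on minibatch $i$: $U_i(\phi) = \phi - \alpha L'_i(\phi)$.
$$g_{\text{MAML}} = \frac{\partial}{\partial \phi_1} L_k(\phi_k) \tag{17}$$
$$= \frac{\partial}{\partial \phi_1} L_k\left(U_{k-1}(U_{k-2}(\dots(U_1(\phi_1))))\right) \tag{18}$$
$$= U'_1(\phi_1) \cdots U'_{k-1}(\phi_{k-1})\, L'_k(\phi_k) \tag{19}$$
$$= \left(I - \alpha L''_1(\phi_1)\right) \cdots \left(I - \alpha L''_{k-1}(\phi_{k-1})\right) L'_k(\phi_k) \tag{20}$$
$$= \left(\prod_{j=1}^{k-1} \left(I - \alpha L''_j(\phi_j)\right)\right) g_k \tag{21}$$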
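In one dimension, each Jacobian $U'_i$ is the scalar $1 - \alpha L''_i$ and the ordering of the product does not matter, so Equation 21 can be compared directly with a finite-difference derivative of $\phi_1 \mapsto L_k(\phi_k)$. The losses, $\alpha$, and helper names below are hypothetical; `fomaml_grad` shows the first-order variant that drops the Jacobian product.

```python
def make_loss(h, q, c):
    """Hypothetical 1-D minibatch loss with its first two derivatives."""
    f = lambda p: 0.5 * h * (p - c) ** 2 + 0.25 * q * (p - c) ** 4
    df = lambda p: h * (p - c) + q * (p - c) ** 3
    d2f = lambda p: h + 3 * q * (p - c) ** 2
    return f, df, d2f


LOSSES = [make_loss(1.0, 0.5, 0.3), make_loss(2.0, 0.2, -0.4), make_loss(0.5, 1.0, 0.1)]
ALPHA = 0.1


def inner_path(phi1):
    """[phi_1, ..., phi_k] with phi_{i+1} = U_i(phi_i) = phi_i - alpha * L_i'(phi_i)."""
    phis = [phi1]
    for _, df, _ in LOSSES[:-1]:
        phis.append(phis[-1] - ALPHA * df(phis[-1]))
    return phis


def maml_grad(phi1):
    """Equation 21: prod_{j<k} (1 - alpha * L_j''(phi_j)) * L_k'(phi_k)."""
    phis = inner_path(phi1)
    jacobian = 1.0
    for (_, _, d2f), p in zip(LOSSES[:-1], phis[:-1]):
        jacobian *= 1.0 - ALPHA * d2f(p)
    return jacobian * LOSSES[-1][1](phis[-1])


def fomaml_grad(phi1):
    """First-order MAML: drop the Jacobian product and keep L_k'(phi_k)."""
    return LOSSES[-1][1](inner_path(phi1)[-1])


def meta_loss(phi1):
    return LOSSES[-1][0](inner_path(phi1)[-1])


def finite_difference(phi1, h=1e-6):
    """Central-difference derivative of the meta-loss phi_1 -> L_k(phi_k)."""
    return (meta_loss(phi1 + h) - meta_loss(phi1 - h)) / (2 * h)


print(maml_grad(0.8), finite_difference(0.8), fomaml_grad(0.8))
```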
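Next, let’s expand to leading order:
$$g_{\text{MAML}} = \left(\prod_{j=1}^{k-1} \left(I - \alpha \bar H_j\right)\right)\left(\bar g_k - \alpha \bar H_k \sum_{j=1}^{k-1} \bar g_j\right) + O(\alpha^2) \quad \text{(replacing } L''_j(\phi_j) \text{ with } \bar H_j\text{)} \tag{22}$$
$$= \left(I - \alpha \sum_{j=1}^{k-1} \bar H_j\right)\left(\bar g_k - \alpha \bar H_k \sum_{j=1}^{k-1} \bar g_j\right) + O(\alpha^2) \tag{23}$$
$$= \bar g_k - \alpha \sum_{j=1}^{k-1} \bar H_j \bar g_k - \alpha \bar H_k \sum_{j=1}^{k-1} \bar g_j + O(\alpha^2) \tag{24}$$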
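For simplicity of exposition, let’s consider the $k = 2$ case, and later we’ll provide the general formulas.
$$g_{\text{MAML}} = \bar g_2 - \alpha \bar H_2 \bar g_1 - \alpha \bar H_1 \bar g_2 + O(\alpha^2) \tag{25}$$
$$g_{\text{FOMAML}} = g_2 = \bar g_2 - \alpha \bar H_2 \bar g_1 + O(\alpha^2) \tag{26}$$
$$g_{\text{Reptile}} = g_1 + g_2 = \bar g_1 + \bar g_2 - \alpha \bar H_2 \bar g_1 + O(\alpha^2) \tag{27}$$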
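The $k = 2$ expressions can be sanity-checked the same way. In this hypothetical 1-D sketch, the exact MAML, FOMAML and Reptile meta-gradients for two SGD steps are compared with Equations 25 through 27; halving $\alpha$ should cut each discrepancy by about a factor of four.

```python
def make_loss(h, q, c):
    """Hypothetical 1-D minibatch loss: returns (L', L'') as functions."""
    df = lambda p: h * (p - c) + q * (p - c) ** 3
    d2f = lambda p: h + 3 * q * (p - c) ** 2
    return df, d2f


(DF1, D2F1), (DF2, D2F2) = make_loss(1.0, 0.5, 0.3), make_loss(2.0, 0.2, -0.4)


def exact(phi1, alpha):
    """Exact k=2 meta-gradients for plain SGD on L_1 then L_2."""
    g1 = DF1(phi1)
    phi2 = phi1 - alpha * g1
    g2 = DF2(phi2)
    return {"MAML": (1 - alpha * D2F1(phi1)) * g2, "FOMAML": g2, "Reptile": g1 + g2}


def expanded(phi1, alpha):
    """Equations 25-27 without the O(alpha^2) remainder."""
    gb1, gb2, hb1, hb2 = DF1(phi1), DF2(phi1), D2F1(phi1), D2F2(phi1)
    return {
        "MAML": gb2 - alpha * hb2 * gb1 - alpha * hb1 * gb2,
        "FOMAML": gb2 - alpha * hb2 * gb1,
        "Reptile": gb1 + gb2 - alpha * hb2 * gb1,
    }


def errors(alpha, phi1=0.8):
    e, a = exact(phi1, alpha), expanded(phi1, alpha)
    return {name: abs(e[name] - a[name]) for name in e}


coarse, fine = errors(0.002), errors(0.001)
for name in coarse:
    print(name, coarse[name] / fine[name])  # close to 4 for an O(alpha^2) remainder
```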
As we will show in the next paragraph, terms like $\alpha \bar H_2 \bar g_1$ serve to maximize the inner products between the gradients computed on different minibatches, while lone gradient terms like $\bar g_1$ take us to the minimum of the joint training problem.
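When we take the expectation of $g_{\text{FOMAML}}$, $g_{\text{Reptile}}$, and $g_{\text{MAML}}$ under minibatch sampling, we are left with only two kinds of terms, which we will call AvgGrad and AvgGradInner. In the equations below, $\mathbb{E}_{\tau,1,2}[\dots]$ means that we are taking the expectation over the task $\tau$ and the two minibatches defining $L_1$ and $L_2$, respectively.
AvgGrad is defined as the gradient of the expected loss:
$$\text{AvgGrad} = \mathbb{E}_{\tau,1}\left[\bar g_1\right] \tag{28}$$
$(-\text{AvgGrad})$ is the direction that brings $\phi$ towards the minimum of the “joint training” problem, i.e., the expected loss over tasks.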
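The more interesting term is AvgGradInner, defined as follows:
$$\text{AvgGradInner} = \mathbb{E}_{\tau,1,2}\left[\bar H_2 \bar g_1\right] \tag{29}$$
$$= \mathbb{E}_{\tau,1,2}\left[\bar H_1 \bar g_2\right] \quad \text{(interchanging indices 1, 2)} \tag{30}$$
$$= \tfrac{1}{2}\, \mathbb{E}_{\tau,1,2}\left[\bar H_2 \bar g_1 + \bar H_1 \bar g_2\right] \tag{31}$$
$$= \tfrac{1}{2}\, \mathbb{E}_{\tau,1,2}\left[\frac{\partial}{\partial \phi_1}\left(\bar g_1 \cdot \bar g_2\right)\right] \tag{32}$$
Thus, AvgGradInner is half the gradient of the expected inner product between gradients of different minibatches for a given task. Because it enters the meta-gradients below with a negative sign, a descent step along $-g$ moves $\phi$ along $+\text{AvgGradInner}$, which increases this inner product and so improves generalization.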
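The identity behind Equations 31 and 32, $\partial(\bar g_1 \cdot \bar g_2)/\partial\phi = \bar H_1 \bar g_2 + \bar H_2 \bar g_1$ for symmetric Hessians, can be checked with finite differences. The sketch assumes two hypothetical 2-D quadratic minibatch losses $\tfrac12(\phi - c_i)^\top A_i(\phi - c_i)$ with symmetric $A_i$, and also confirms that a small step along $+(\bar H_1 \bar g_2 + \bar H_2 \bar g_1)$ increases the inner product.

```python
A1, C1 = [[2.0, 0.5], [0.5, 1.0]], [1.0, -1.0]   # hypothetical minibatch 1
A2, C2 = [[1.0, -0.3], [-0.3, 3.0]], [0.5, 2.0]  # hypothetical minibatch 2


def matvec(A, v):
    return [A[0][0] * v[0] + A[0][1] * v[1], A[1][0] * v[0] + A[1][1] * v[1]]


def grad(A, c, phi):
    """Gradient of 0.5 * (phi - c)^T A (phi - c); its Hessian is the symmetric A."""
    return matvec(A, [phi[0] - c[0], phi[1] - c[1]])


def dot(u, v):
    return u[0] * v[0] + u[1] * v[1]


def inner(phi):
    """gbar_1 . gbar_2 evaluated at phi."""
    return dot(grad(A1, C1, phi), grad(A2, C2, phi))


def analytic(phi):
    """H_1 g_2 + H_2 g_1 (Equation 31 without the 1/2 and the expectation)."""
    a = matvec(A1, grad(A2, C2, phi))
    b = matvec(A2, grad(A1, C1, phi))
    return [a[0] + b[0], a[1] + b[1]]


def numeric(phi, h=1e-5):
    """Central finite differences of inner(phi)."""
    out = []
    for i in range(2):
        up, down = list(phi), list(phi)
        up[i] += h
        down[i] -= h
        out.append((inner(up) - inner(down)) / (2 * h))
    return out


phi = [0.2, 0.3]
v = analytic(phi)
stepped = [phi[0] + 1e-3 * v[0], phi[1] + 1e-3 * v[1]]
print(v, numeric(phi))
print(inner(phi), inner(stepped))  # the step along +v increases the inner product
```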
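Recalling our gradient expressions, we get the following expressions for the meta-gradients, for SGD with $k = 2$:
$$\mathbb{E}[g_{\text{MAML}}] = (1)\,\text{AvgGrad} - (2\alpha)\,\text{AvgGradInner} + O(\alpha^2) \tag{33}$$
$$\mathbb{E}[g_{\text{FOMAML}}] = (1)\,\text{AvgGrad} - (\alpha)\,\text{AvgGradInner} + O(\alpha^2) \tag{34}$$
$$\mathbb{E}[g_{\text{Reptile}}] = (2)\,\text{AvgGrad} - (\alpha)\,\text{AvgGradInner} + O(\alpha^2) \tag{35}$$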
In practice, all three gradient expressions first bring us towards the minimum of the expected loss over tasks, then the higher-order AvgGradInner term enables fast learning by maximizing the inner product between gradients within a given task.
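Finally, we can extend these calculations to the general $k \ge 2$ case:
$$g_{\text{MAML}} = \bar g_k - \alpha \bar H_k \sum_{j=1}^{k-1} \bar g_j - \alpha \sum_{j=1}^{k-1} \bar H_j \bar g_k + O(\alpha^2) \tag{36}$$
$$\mathbb{E}[g_{\text{MAML}}] = (1)\,\text{AvgGrad} - \left(2(k-1)\alpha\right)\text{AvgGradInner} + O(\alpha^2) \tag{37}$$
$$g_{\text{FOMAML}} = g_k = \bar g_k - \alpha \bar H_k \sum_{j=1}^{k-1} \bar g_j + O(\alpha^2) \tag{38}$$
$$\mathbb{E}[g_{\text{FOMAML}}] = (1)\,\text{AvgGrad} - \left((k-1)\alpha\right)\text{AvgGradInner} + O(\alpha^2) \tag{39}$$
$$g_{\text{Reptile}} = -(\phi_{k+1} - \phi_1)/\alpha = \sum_{i=1}^{k} g_i = \sum_{i=1}^{k} \bar g_i - \alpha \sum_{i=1}^{k} \sum_{j=1}^{i-1} \bar H_i \bar g_j + O(\alpha^2) \tag{40}$$
$$\mathbb{E}[g_{\text{Reptile}}] = (k)\,\text{AvgGrad} - \left(\tfrac{1}{2}k(k-1)\alpha\right)\text{AvgGradInner} + O(\alpha^2) \tag{41}$$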
As in the $k = 2$ case, the ratio of coefficients of the AvgGradInner term and the AvgGrad term goes MAML > FOMAML > Reptile. However, in all cases, this ratio increases linearly with both the stepsize $\alpha$ and the number of iterations $k$. Note that the Taylor series approximation only holds for small $\alpha k$.
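The coefficients in Equations 37, 39 and 41 come from counting cross terms, which the short sketch below tabulates with exact rational arithmetic (the value $\alpha = 1/10$ is arbitrary). It recovers the $k = 2$ coefficients of Equations 33 through 35 and shows the ratio ordering MAML > FOMAML > Reptile, with each ratio proportional to $(k - 1)\alpha$.

```python
from fractions import Fraction


def reptile_inner_terms(k):
    """Count the cross terms Hbar_i gbar_j (1 <= j < i <= k) in Equation 40."""
    return sum(1 for i in range(1, k + 1) for j in range(1, i))


def coefficients(k, alpha):
    """(AvgGrad coefficient, AvgGradInner coefficient) from Equations 37, 39, 41."""
    return {
        "MAML": (1, 2 * (k - 1) * alpha),        # (k-1) terms of each kind in Eq. 36
        "FOMAML": (1, (k - 1) * alpha),
        "Reptile": (k, reptile_inner_terms(k) * alpha),
    }


alpha = Fraction(1, 10)
for k in (2, 5, 10):
    ratios = {name: inner / avg for name, (avg, inner) in coefficients(k, alpha).items()}
    print(k, {name: str(r) for name, r in ratios.items()})
```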
Here, we argue that Reptile converges towards a solution that is close (in Euclidean distance) to each task $\tau$’s manifold of optimal solutions. This is an informal argument and should be taken much less seriously than the preceding Taylor series analysis.
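Let $\phi$ denote the network initialization, and let $\mathcal{W}_\tau$ denote the set of optimal parameters for task $\tau$. We want to find $\phi$ such that the distance $D(\phi, \mathcal{W}_\tau)$ is small for all tasks.
$$\min_{\phi}\; \mathbb{E}_\tau\left[\tfrac{1}{2} D(\phi, \mathcal{W}_\tau)^2\right] \tag{42}$$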
We will show that Reptile corresponds to performing SGD on that objective.
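Given a non-pathological set $\mathcal{S} \subset \mathbb{R}^d$, then for almost all points $\phi \in \mathbb{R}^d$ the gradient of the squared distance $D(\phi, \mathcal{S})^2$ is $2(\phi - P_{\mathcal{S}}(\phi))$, where $P_{\mathcal{S}}(\phi)$ is the projection (closest point) of $\phi$ onto $\mathcal{S}$. Thus,
$$\nabla_\phi\, \mathbb{E}_\tau\left[\tfrac{1}{2} D(\phi, \mathcal{W}_\tau)^2\right] = \mathbb{E}_\tau\left[\tfrac{1}{2} \nabla_\phi D(\phi, \mathcal{W}_\tau)^2\right] \tag{43}$$
$$= \mathbb{E}_\tau\left[\phi - P_{\mathcal{W}_\tau}(\phi)\right], \quad \text{where } P_{\mathcal{W}_\tau}(\phi) = \arg\min_{p \in \mathcal{W}_\tau} D(p, \phi) \tag{44}$$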
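The gradient formula $\nabla_\phi D(\phi, \mathcal{S})^2 = 2(\phi - P_{\mathcal{S}}(\phi))$ also applies to non-convex sets, away from a measure-zero set of points. The sketch below checks it numerically for a hypothetical solution set that is a circle in $\mathbb{R}^2$, where the projection is undefined only at the center.

```python
import math

CENTER, RADIUS = (1.0, -0.5), 2.0  # hypothetical solution manifold: a circle


def project(phi):
    """Closest point of the circle to phi (undefined only at the center)."""
    dx, dy = phi[0] - CENTER[0], phi[1] - CENTER[1]
    r = math.hypot(dx, dy)
    return (CENTER[0] + RADIUS * dx / r, CENTER[1] + RADIUS * dy / r)


def sq_dist(phi):
    """D(phi, S)^2, the squared distance to the circle."""
    p = project(phi)
    return (phi[0] - p[0]) ** 2 + (phi[1] - p[1]) ** 2


def numeric_grad(f, phi, h=1e-6):
    """Central finite-difference gradient of f at phi."""
    g = []
    for i in range(2):
        up, down = list(phi), list(phi)
        up[i] += h
        down[i] -= h
        g.append((f(up) - f(down)) / (2 * h))
    return g


phi = [3.5, 1.0]
p = project(phi)
formula = [2 * (phi[0] - p[0]), 2 * (phi[1] - p[1])]
print(numeric_grad(sq_dist, phi), formula)
```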
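Each iteration of Reptile corresponds to sampling a task $\tau$ and performing a stochastic gradient update
$$\phi \leftarrow \phi - \epsilon \nabla_\phi \tfrac{1}{2} D(\phi, \mathcal{W}_\tau)^2 \tag{45}$$
$$= \phi - \epsilon\left(\phi - P_{\mathcal{W}_\tau}(\phi)\right) \tag{46}$$
$$= (1 - \epsilon)\,\phi + \epsilon\, P_{\mathcal{W}_\tau}(\phi) \tag{47}$$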
In practice, we can’t exactly compute $P_{\mathcal{W}_\tau}(\phi)$, which is defined as the minimizer of $L_\tau$ closest to $\phi$. However, we can partially minimize this loss using gradient descent. Hence, in Reptile we replace $P_{\mathcal{W}_\tau}(\phi)$ by the result of running $k$ steps of gradient descent on $L_\tau$ starting with initialization $\phi$.
We evaluate our method on two popular few-shot classification tasks: Omniglot [11] and Mini-ImageNet [18]. These datasets make it easy to compare our method to other few-shot learning approaches like MAML.
In few-shot classification tasks, we have a meta-dataset $D$ containing many classes $C$, where each class is itself a set of example instances $\{c_1, c_2, \dots, c_n\}$. If we are doing $K$-shot, $N$-way classification, then we sample tasks by selecting $N$ classes from $D$ and then selecting $K + 1$ examples for each class. We split these examples into a training set and a test set, where the test set contains a single example for each class. The model gets to see the entire training set, and then it must classify a randomly chosen sample from the test set. For example, if you trained a model for 5-shot, 5-way classification, then you would show it 25 examples (5 per class) and ask it to classify a 26th example.
In addition to the above setup, we also experimented with the transductive setting, where the model classifies the entire test set at once. In our transductive experiments, information was shared between the test samples via batch normalization [9]. In our non-transductive experiments, batch normalization statistics were computed using all of the training samples and a single test sample. We note that Finn et al. [4] use transduction for evaluating MAML.
For our experiments, we used the same CNN architectures and data preprocessing as Finn et al. [4]. We used the Adam optimizer [10] in the inner loop, and vanilla SGD in the outer loop, throughout our experiments. For Adam we set $\beta_1 = 0$ because we found that momentum reduced performance across the board.¹
¹ This finding also matches our analysis from Section 5.1, which suggests that Reptile works because sequential steps come from different mini-batches. With momentum, a mini-batch has influence over the next few steps, reducing this effect.
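The $K$-shot, $N$-way task construction described at the start of this section can be sketched as follows. The dictionary format of the meta-dataset and the function name `sample_task` are hypothetical; labels are re-indexed to $0, \dots, N-1$ within each task.

```python
import random


def sample_task(meta_dataset, n_way, k_shot, rng):
    """Sample a K-shot, N-way task from a hypothetical {class_name: [examples]} dict.

    Returns (train, test): K labeled examples per class for training and one
    held-out example per class for testing.
    """
    classes = rng.sample(sorted(meta_dataset), n_way)
    train, test = [], []
    for label, cls in enumerate(classes):
        examples = rng.sample(meta_dataset[cls], k_shot + 1)  # K + 1 per class
        train.extend((x, label) for x in examples[:k_shot])
        test.append((examples[k_shot], label))
    return train, test


# Toy meta-dataset: 20 classes with 20 string "images" each.
meta = {f"class{c}": [f"c{c}_img{i}" for i in range(20)] for c in range(20)}
rng = random.Random(0)
train, test = sample_task(meta, n_way=5, k_shot=5, rng=rng)
query = rng.choice(test)  # the non-transductive setting classifies one test example
print(len(train), len(test), query)
```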
During training, we never reset or interpolated Adam’s rolling moment data; instead, we let it update automatically at every inner-loop training step. However, we did back up and reset the Adam statistics when evaluating on the test set to avoid information leakage.
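The optimizer-state handling described above can be sketched with a minimal list-based Adam. This is a hypothetical helper written for illustration, not a library API: it uses $\beta_1 = 0$, keeps its moment estimates across training steps, and exposes `snapshot`/`restore` so the state can be backed up before evaluation and reset afterwards. The default learning rate and $\beta_2$ are assumptions.

```python
import copy
import math


class Adam:
    """Minimal list-based Adam (hypothetical helper, not a library API)."""

    def __init__(self, lr=1e-3, beta1=0.0, beta2=0.999, eps=1e-8):
        self.lr, self.beta1, self.beta2, self.eps = lr, beta1, beta2, eps
        self.state = {"t": 0, "m": None, "v": None}

    def step(self, params, grads):
        s = self.state
        if s["m"] is None:
            s["m"], s["v"] = [0.0] * len(params), [0.0] * len(params)
        s["t"] += 1
        new_params = []
        for i, (p, g) in enumerate(zip(params, grads)):
            s["m"][i] = self.beta1 * s["m"][i] + (1 - self.beta1) * g  # beta1=0: m = g
            s["v"][i] = self.beta2 * s["v"][i] + (1 - self.beta2) * g * g
            m_hat = s["m"][i] / (1 - self.beta1 ** s["t"])
            v_hat = s["v"][i] / (1 - self.beta2 ** s["t"])
            new_params.append(p - self.lr * m_hat / (math.sqrt(v_hat) + self.eps))
        return new_params

    def snapshot(self):
        return copy.deepcopy(self.state)

    def restore(self, snapshot):
        self.state = copy.deepcopy(snapshot)


opt = Adam()
params = opt.step([0.5, -0.5], [1.0, -2.0])  # training: moments keep accumulating
backup = opt.snapshot()                       # back up before evaluating a test task
eval_params = params
for _ in range(50):
    eval_params = opt.step(eval_params, [0.3, 0.1])  # evaluation-time inner loop
opt.restore(backup)                           # discard evaluation-time statistics
print(opt.state["t"], opt.state["m"])
```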
The results on Omniglot and Mini-ImageNet are shown in Tables 2 and 1. While MAML, FOMAML, and Reptile have very similar performance on all of these tasks, Reptile does slightly better than the alternatives on Mini-ImageNet and slightly worse on Omniglot. It also seems that transduction gives a performance boost in all cases, suggesting that further research should pay close attention to its use of batch normalization during testing.
Algorithm | 1-shot 5-way | 5-shot 5-way |
---|---|---|
MAML + Transduction | | |
First-order MAML + Transduction | | |
Reptile | | |
Reptile + Transduction | | |
For this experiment, we used four non-overlapping mini-batches in each inner-loop, yielding gradients $g_1$, $g_2$, $g_3$, and $g_4$. We then compared learning performance when using different linear combinations of the $g_i$’s for the outer loop update. Note that two-step Reptile corresponds to $g_1 + g_2$, and two-step FOMAML corresponds to $g_2$.
To make it easier to get an apples-to-apples comparison between different linear combinations, we simplified our experimental setup in several ways. First, we used vanilla SGD in the inner- and outer-loops. Second, we did not use meta-batches. Third, we restricted our experiments to 5-shot, 5-way Omniglot. With these simplifications, we did not have to worry as much about the effects of hyper-parameters or optimizers.
Figure 3 shows the learning curves for various inner-loop gradient combinations. For gradient combinations with more than one term, we ran both a sum and an average of the inner gradients to correct for the effective step size increase.
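The gradient combinations compared in this experiment can be expressed as weight vectors over $(g_1, g_2, g_3, g_4)$. The sketch below uses four hypothetical 1-D minibatch losses and plain SGD; it shows that the four-step Reptile sum equals $(\phi - \tilde\phi)/\alpha$ and that the averaged variant divides by the number of terms.

```python
def make_grad(h, q, c):
    """Gradient of a hypothetical 1-D minibatch loss 0.5*h*(p-c)^2 + 0.25*q*(p-c)^4."""
    return lambda p: h * (p - c) + q * (p - c) ** 3


MINIBATCH_GRADS = [make_grad(1.0, 0.5, 0.3), make_grad(2.0, 0.2, -0.4),
                   make_grad(0.5, 1.0, 0.1), make_grad(1.5, 0.3, 0.6)]


def inner_gradients(phi, alpha):
    """Four SGD steps on four minibatches; returns [g1, g2, g3, g4] and the final phi."""
    gs = []
    for grad in MINIBATCH_GRADS:
        g = grad(phi)
        gs.append(g)
        phi -= alpha * g
    return gs, phi


COMBINATIONS = {
    "g1 (expected-loss gradient)": [1, 0, 0, 0],
    "g2 (two-step FOMAML)": [0, 1, 0, 0],
    "g1+g2 (two-step Reptile)": [1, 1, 0, 0],
    "g4 (four-step FOMAML)": [0, 0, 0, 1],
    "g1+g2+g3+g4 (four-step Reptile)": [1, 1, 1, 1],
}


def combine(gs, weights, average=False):
    """Weighted sum of the g_i, or their average to undo the larger effective step."""
    total = sum(w * g for w, g in zip(weights, gs))
    return total / sum(weights) if average else total


phi0, alpha = 0.8, 0.1
gs, phi4 = inner_gradients(phi0, alpha)
for name, w in COMBINATIONS.items():
    print(name, combine(gs, w), combine(gs, w, average=True))
```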
As expected, using only the first gradient $g_1$ is quite ineffective, since it amounts to optimizing the expected loss over all tasks. Surprisingly, two-step Reptile is noticeably worse than two-step FOMAML, which might be explained by the fact that two-step Reptile puts less weight on AvgGradInner relative to AvgGrad (Equations 35 and 34). Most importantly, though, all the methods improve as the number of mini-batches increases. This improvement is more significant when using a sum of all gradients (Reptile) rather than using just the final gradient (FOMAML). This also suggests that Reptile can benefit from taking many inner loop steps, which is consistent with the optimal hyper-parameters found for Section 6.1.
Both Reptile and FOMAML use stochastic optimization in their inner-loops. Small changes to this optimization procedure can lead to large changes in final performance. This section explores the sensitivity of Reptile and FOMAML to the inner loop hyperparameters, and also shows that FOMAML’s performance significantly drops if mini-batches are selected the wrong way.
The experiments in this section look at the difference between shared-tail FOMAML, where the final inner-loop mini-batch comes from the same set of data as the earlier inner-loop batches, and separate-tail FOMAML, where the final mini-batch comes from a disjoint set of data. Viewing FOMAML as an approximation to MAML, separate-tail FOMAML can be seen as the more correct approach (and was used by Finn et al. [4]), since the training-time optimization resembles the test-time optimization (where the test set doesn’t overlap with the training set). Indeed, we find that separate-tail FOMAML is significantly better than shared-tail FOMAML. As we will show, shared-tail FOMAML degrades in performance when the data used to compute the meta-gradient ($g_{\text{FOMAML}} = g_k$) overlaps significantly with the earlier batches; however, Reptile and separate-tail FOMAML maintain performance and are not very sensitive to the inner-loop hyperparameters.
Figure 3(a) shows that when minibatches are selected by cycling through the training data (shared-tail, cycle), shared-tail FOMAML performs well up to four inner-loop iterations, but drops in performance starting at five iterations, where the final minibatch (used to compute $g_{\text{FOMAML}} = g_k$) overlaps with the earlier ones. When we use random sampling instead (shared-tail, replacement), shared-tail FOMAML degrades more gradually. We hypothesize that this is because the final batch still contains some samples that did not appear in the previous batches. The effect is stochastic, so it makes sense that the curve is smoother.
Figure 3(b) shows a similar phenomenon, but here we fixed the inner-loop to four iterations and instead varied the batch size. For batch sizes greater than 25, the final inner-loop batch for shared-tail FOMAML necessarily contains samples from the previous batches. Similar to Figure 3(a), here we observe that shared-tail FOMAML with random sampling degrades more gradually than shared-tail FOMAML with cycling.
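The overlap effect can be made concrete by counting how much of the final inner-loop batch was already used earlier. The sketch assumes 100 training samples per task (inferred from the statement that, with four iterations, batch sizes above 25 force overlap) and treats "replacement" as drawing each batch independently at random; both are assumptions. Separate-tail FOMAML draws the final batch from disjoint data, so its overlap is zero by construction.

```python
import random


def cycle_batches(n, batch_size, iters):
    """Shared-tail, cycle: walk through indices 0..n-1 in order, wrapping around."""
    return [[(i * batch_size + j) % n for j in range(batch_size)] for i in range(iters)]


def random_batches(n, batch_size, iters, rng):
    """Shared-tail, random: each batch drawn independently (an assumed reading)."""
    return [rng.sample(range(n), batch_size) for _ in range(iters)]


def tail_overlap(batches):
    """Fraction of the final batch (used for g_FOMAML) already seen in earlier batches."""
    seen = set()
    for batch in batches[:-1]:
        seen.update(batch)
    return sum(1 for x in batches[-1] if x in seen) / len(batches[-1])


N = 100  # assumed training-set size per task: four disjoint batches of 25
print(tail_overlap(cycle_batches(N, 25, 4)))  # four disjoint batches
print(tail_overlap(cycle_batches(N, 25, 5)))  # the fifth batch repeats the first
print(tail_overlap(cycle_batches(N, 30, 4)))  # batch size above 25 forces overlap
print(tail_overlap(random_batches(N, 25, 4, random.Random(0))))
```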
In both of these parameter sweeps, separate-tail FOMAML and Reptile do not degrade in performance as the number of inner-loop iterations or batch size changes.
There are several possible explanations for the above findings. For example, one might hypothesize that shared-tail FOMAML is only worse in these experiments because its effective step size is much lower than that of separate-tail FOMAML. However, Figure 3(c) suggests that this is not the case: performance was equally poor for every choice of step size in a thorough sweep. A different hypothesis is that shared-tail FOMAML performs poorly because, after a few inner-loop steps on a sample, the gradient of the loss for that sample does not contain very much useful information about the sample. In other words, the first few SGD steps might bring the model close to a local optimum, and then further SGD steps might simply bounce around this local optimum.
Meta-learning algorithms that perform gradient descent at test time are appealing because of their simplicity and generalization properties [5]. The effectiveness of fine-tuning (e.g. from models trained on ImageNet [2]) gives us additional faith in these approaches. This paper proposed a new algorithm called Reptile, whose training process is only subtly different from joint training and only uses first-order gradient information (like first-order MAML).
We gave two theoretical explanations for why Reptile works. First, by approximating the update with a Taylor series, we showed that SGD automatically gives us the same kind of second-order term that MAML computes. This term adjusts the initial weights to maximize the dot product between the gradients of different minibatches on the same task—i.e., it encourages the gradients to generalize between minibatches of the same task. We also provided a second informal argument, which is that Reptile finds a point that is close (in Euclidean distance) to all of the optimal solution manifolds of the training tasks.
While this paper studies the meta-learning setting, the Taylor series analysis in Section 5.1 may have some bearing on stochastic gradient descent in general. It suggests that when doing stochastic gradient descent, we are automatically performing a MAML-like update that maximizes the generalization between different minibatches. This observation partly explains why fine tuning (e.g., from ImageNet to a smaller dataset [20]) works well. This hypothesis would suggest that joint training plus fine tuning will continue to be a strong baseline for meta-learning in various machine learning problems.
We see several promising directions for future work:
Understanding to what extent SGD automatically optimizes for generalization, and whether this effect can be amplified in the non-meta-learning setting.
Applying Reptile in the reinforcement learning setting. So far, we have obtained negative results because joint training is a strong baseline; some modifications to Reptile might be necessary.
Exploring whether Reptile’s few-shot learning performance can be improved by deeper architectures for the classifier.
Exploring whether regularization can improve few-shot learning performance, as currently there is a large gap between training and testing error.
Evaluating Reptile on the task of few-shot density modeling [14].
For all experiments, we linearly annealed the outer step size to 0. We ran each experiment with three different random seeds, and computed the confidence intervals using the standard deviation across the runs.
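A linear annealing schedule of this kind might look as follows. The exact form (the initial value at iteration 0, reaching 0 at the last iteration) and the function name are assumptions.

```python
def outer_step_size(iteration, total_iterations, initial=1.0):
    """Outer step size linearly annealed from `initial` at iteration 0 to 0 at the end."""
    return initial * (1.0 - iteration / total_iterations)


# e.g. an outer step size of 1.0 over 100K outer iterations
print([outer_step_size(i, 100_000) for i in (0, 25_000, 50_000, 100_000)])
```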
Initially, we tried optimizing the Reptile hyper-parameters using CMA-ES [6]. However, we found that most hyper-parameters had little effect on the resulting performance. After seeing this result, we simplified all of the hyper-parameters and shared hyper-parameters between experiments when it made sense.
Parameter | 5-way | 20-way |
---|---|---|
Adam learning rate | 0.001 | 0.0005 |
Inner batch size | 10 | 20 |
Inner iterations | 5 | 10 |
Training shots | 10 | 10 |
Outer step size | 1.0 | 1.0 |
Outer iterations | 100K | 200K |
Meta-batch size | 5 | 5 |
Eval. inner iterations | 50 | 50 |
Eval. inner batch size | 5 | 10 |
Parameter | 1-shot | 5-shot |
---|---|---|
Adam learning rate | ||
Inner batch size | 10 | 10 |
Inner iterations | 8 | 8 |
Training shots | 15 | 15 |
Outer step size | 1.0 | 1.0 |
Outer iterations | 100K | 100K |
Meta-batch size | 5 | 5 |
Eval. inner batch size | 5 | 15 |
Eval. inner iterations | 50 | 50 |
Parameter | Value |
---|---|
Inner learning rate | |
Inner batch size | 25 |
Outer step size | 0.25 |
Outer iterations | 40K |
Eval. inner batch size | 25 |
Eval. inner iterations | 5 |
Parameter | Figure 3(b) | Figure 3(a) | Figure 3(c) |
---|---|---|---|
Inner learning rate | |||
Inner batch size | - | 25 | 100 |
Inner iterations | 4 | - | 4 |
Outer step size | 1.0 | 1.0 | - |
Outer iterations | 40K | 40K | 40K |
Eval. inner batch size | 25 | 25 | 25 |
Eval. inner iterations | 5 | 5 | 5 |