Gradient-EM Bayesian Meta-learning

06/21/2020 ∙ by Yayi Zou, et al. ∙ Columbia University ∙ Cornell University

Bayesian meta-learning enables robust and fast adaptation to new tasks with uncertainty assessment. The key idea behind Bayesian meta-learning is empirical Bayes inference of a hierarchical model. In this work, we extend this framework to include a variety of existing methods, and then propose our variant based on the gradient-EM algorithm. Our method improves computational efficiency by avoiding back-propagation in the meta-update step, which is expensive for deep neural networks. Furthermore, it provides flexibility to the inner-update optimization procedure by decoupling it from the meta-update. Experiments on sinusoidal regression, few-shot image classification, and policy-based reinforcement learning show that our method not only achieves better accuracy at lower computational cost, but is also more robust to uncertainty.




1 Introduction

Meta-learning, also known as learning to learn, has gained tremendous attention in both academia and industry, especially with applications to few-shot learning finn2017model. The multi-task setting in meta-learning assumes that tasks share similarities, so that learning from sufficiently many tasks helps to master new tasks faster. This feature is referred to as fast adaptation.

Early fast meta-learning algorithms were gradient-based and deterministic, which may cause overfitting at both the inner level and the meta level mishra2017meta. With growing interest in evaluating prediction uncertainty and controlling overfitting, later studies explored probabilistic meta-learning methods grant2018recasting; yoon2018bayesian; finn2018probabilistic. Bayesian inference is widely regarded as one of the most convenient choices because of its Occam's Razor property mackay2003information, which automatically prevents overfitting, a frequent problem in deep neural networks (DNNs). It also provides reliable predictive uncertainty because of its probabilistic nature. This makes Bayesian methods important for DNNs, which, as guo2017calibration shows, are usually poorly calibrated on predictive uncertainty, unlike shallow neural networks.

The theoretical foundation of Bayesian meta-learning is hierarchical Bayes (HB) good1980some or empirical Bayes (EB) robbins1985stochastic, which restricts the learning of meta-parameters to point estimates. For simplicity we focus on EB in this paper; it can always be extended to HB by adding a hyper-prior to the learning of meta-parameters, as in ravi2018amortized. A common solution of EB is a bi-level iterative optimization procedure ravi2018amortized; eb, where the “inner-update” refers to adaptation to a given task, and the “meta-update” optimizes the meta-training objective. We extend the original optimization framework with a train/val split in the inner-update procedure to mitigate in-task overfitting, which is important for NN-based ML. We also hypothesize a mechanism by which the EB framework achieves fast adaptation (few inner-update gradient steps) under Gaussian parameterization, and support it with empirical evidence. Moreover, we adapt this EB framework to RL both theoretically and empirically, which has not been done before.

We show that many important previous works in (Bayesian) meta-learning ravi2018amortized; finn2018probabilistic; yoon2018bayesian; finn2017model; nichol2018first can be included in this extended framework. However, in previous works, the meta-update step requires backpropagation through the inner optimization process rajeswaran2019meta, which imposes a computation and memory burden that grows with the number of inner-update gradient steps. This limits possible applications, especially those that require many inner-update gradient steps or involve large datasets (see the Appendix for our examples). Motivated by these observations, we propose a gradient-based Bayesian algorithm inspired by the Gradient-EM algorithm. By designing a new way to compute the gradient of the meta loss in Bayesian MAML that uses the gradient information of the prior distribution, we obtain an algorithm that decouples the meta-update from the inner-update and thus avoids the computation and memory burden of previous methods, making it scalable to a large number of inner-update gradient steps. In addition, it allows great flexibility in the choice of inner-update optimization method, because it only requires the result of the inner-update optimization rather than the optimization process itself (for example, in our experiments we use Adam for classification and Trust Region Policy Optimization for RL). The separability of meta-update and inner-update also makes it a potentially useful scheme for distributed learning and private learning.

In experiments on sinusoidal regression, image classification benchmarks and reinforcement learning, we show that our method can quickly learn the knowledge and uncertainty of a novel task with a lower computational burden.

2 Problem Formulation and Framework

2.1 General Meta Learning Setting

We set up the K-shot meta-learning framework on reinforcement learning (RL) with episode length as in finn2017model, where supervised learning is a special case with . With a decision rule (policy) we can sample rollout data from the task environment. A decision rule (policy) can be evaluated on with loss function . We assume each task to be sampled i.i.d. from the task space , following some task distribution . During the meta-training phase, we collect samples from rollouts of the current policy and another samples from rollouts after one policy-gradient training step of ( is not needed to generate samples in supervised learning). We denote these samples as . In the meta-testing phase, for a randomly sampled task , samples are first provided as . We are then required to return based on , and its expected loss is evaluated on more samples generated from that task. The objective is to come up with an algorithm that produces a decision rule that minimizes the expected test loss over all tasks .




Figure 1: (a) Minimal (b) Graphical Model

2.2 Extended Empirical Bayes Meta-Learning Framework

We consider a parameterized decision rule and construct a corresponding generative model : (we leave the details of this construction in RL to the Appendix). For each task , denote the best policy parameter, as well as the best-fitted underlying generative model parameter, by . In general such a maximum is not unique, which is discussed in Section 2.3. With uniquely defined, we have a distribution induced by (change of variable). (See Figure 1(b) for a summary of the generative model.) Under perfect approximation, the ground-truth generator matches our generative model: , resulting in the following proposition:

Theorem 1.

Suppose data generator is represented by the hierarchical model and , and define for distribution over . Let be independent samples from task , and consider determined by via . Then


Two observations are made here. First, this theorem guarantees the best decision rule we can come up with during meta-testing: evaluating the posterior using the prior . Second, this theorem suggests an estimation method for during meta-training: . We prove in the Appendix that this estimator is not only asymptotically consistent but also asymptotically normal, which means it quickly converges to the true value with small variance as the number of tasks increases. We further parameterize by (in hierarchical modeling this is called the hyper-prior), and introduce the short notation ; the optimization in meta-training can then be written as . This is an extension of the popular MLE approach in empirical Bayes, where the (marginal) log-likelihood is the special case of . For clarity we denote and , . There is a bias/variance trade-off between and . Using as the meta loss function mitigates the in-task overfitting problem, while extracts more information from the data. A detailed discussion is presented in the Appendix.

A stochastic gradient descent (SGD) approach to meta-training is provided in Algorithm 1: at iteration , the gradient for each task in the -th meta-training batch is computed by the subroutine Meta-Gradient, then gradient ascent is performed. A variational inference (VI) approach to meta-testing is shown in Algorithm 2, where the posterior is estimated with fixed . Detailed discussions of these subroutines are presented in Section 3.

2.3 Non-uniqueness and Fast Adaptation

For neural networks , there exist many local optima that achieve equally good performance for each task. We observe from an empirical study [Appendix] that the key to fast adaptation for gradient-based algorithms is to find a small neighbouring zone in which most tasks have at least one best parameter (Figure 1(a)). The intuition is that when are close enough, they can be learned within a few gradient steps starting from any point within that neighbouring zone (our experiment shows that perturbing the initial point within that area still gives good performance at meta-test). The existence of this small neighbouring zone depends on the parametric model and the task distribution . We further demonstrate [Appendix] its existence with a Gaussian parameterization of for uni-modal task distributions such as sinusoidal functions and neural networks. Even if we fail to find a single small neighbouring zone (e.g. for a multi-modal task distribution such as a mixture of sinusoidal, linear and quadratic functions), a solution may be provided by an extension to Gaussian mixtures grant2018modulating; rasmussen2000infinite. In this work we focus on the uni-modal situation and leave the extension to future work.

Algorithm Meta-train()
    randomly initialize
    while not done do
        Sample batch of tasks
        for each task  do
            Sample
            Compute by Subroutine Meta-Gradient(,)
        end for
    end while
Algorithm 1 Extended Empirical Bayes Meta-learning Framework
Algorithm Meta-test()
    Require: learned , from new task
    Compute posterior VI(, ).
    Sample
    return for evaluation
Subroutine VI(,)
    Initialize at .
    while not done do
        Sample from .
    end while
    return
Algorithm 2 VI: reparameterize using a differentiable transformation of an auxiliary noise variable such that with kingma2013auto
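
As an illustration of how Algorithms 1 and 2 fit together, the following sketch runs the outer loop on a toy problem in which each task is a scalar Gaussian mean. Everything in it is an assumption made for illustration and not part of the method's specification: the names `sample_task`, `posterior`, `meta_gradient`, `meta_train` and `meta_test`, the fixed prior standard deviation, the task distribution, the learning rate, and the use of an exact conjugate posterior in place of the VI subroutine. The meta-gradient only reads the final posterior parameters, which is the decoupling property discussed in Section 3.2.

```python
import random

PRIOR_STD = 1.0   # assumed fixed; only the prior mean is meta-learned in this toy
N_SHOTS = 5       # observations per task (hypothetical choice)


def sample_task(rng):
    """Draw a task (its true mean), then N_SHOTS noisy observations of it."""
    theta = rng.gauss(3.0, 0.5)
    return [rng.gauss(theta, 1.0) for _ in range(N_SHOTS)]


def posterior(data, prior_mean):
    """Stand-in for subroutine VI: exact Gaussian posterior (mean, variance)
    for x_i ~ N(theta, 1) with prior theta ~ N(prior_mean, PRIOR_STD^2)."""
    precision = 1.0 / PRIOR_STD ** 2 + len(data)
    mean = (prior_mean / PRIOR_STD ** 2 + sum(data)) / precision
    return mean, 1.0 / precision


def meta_gradient(data, prior_mean):
    """Gradient of the log marginal likelihood w.r.t. the prior mean,
    computed from the posterior parameters alone (no backpropagation
    through the inner optimization)."""
    post_mean, _ = posterior(data, prior_mean)
    return (post_mean - prior_mean) / PRIOR_STD ** 2


def meta_train(n_iters=500, batch_size=10, lr=0.1, seed=0):
    rng = random.Random(seed)
    prior_mean = 0.0
    for _ in range(n_iters):
        batch = [sample_task(rng) for _ in range(batch_size)]
        grad = sum(meta_gradient(d, prior_mean) for d in batch) / batch_size
        prior_mean += lr * grad  # gradient ascent on the meta-objective
    return prior_mean


def meta_test(data, prior_mean):
    """Adapt to a new task: posterior under the learned prior."""
    return posterior(data, prior_mean)
```

In this toy the learned prior mean drifts toward the centre of the task distribution, and `meta_test` then shrinks a new task's estimate toward it.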

3 Method

In this section, we first introduce the gradient-based variational inference subroutine VI, which is related to a variety of existing methods. We then present our proposed subroutine Meta-Gradient, inspired by the Gradient-EM algorithm, and compare it with the most widely used existing methods for this subroutine.

3.1 Variational Inference

Notice that this framework requires computing the posterior on complex models such as neural networks. We approximate the posterior with the same parametric distribution that we use to approximate the prior, and use variational inference to compute the parameters, as has been done in previous work ravi2018amortized. Let be the approximation of the posterior obtained by minimizing their KL divergence. Since is constant in terms of , we have . So the inference process is to find to maximize the Evidence Lower Bound (ELBO) via mini-batch gradient descent ravi2018amortized. The gradients of the KL-divergence terms are calculated analytically in the Gaussian case, whereas the gradients of the expectations can be computed by Monte Carlo with reparameterization, along with some variance reduction tricks kingma2015variational; zhang2018single. Following the analysis in Section 2.3, only a few gradient steps are needed for this process when is well learned by our framework. We summarize the subroutine VI in Algorithm 2.
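
To make the two gradient sources concrete, here is a minimal sketch of the VI step for a scalar toy model, assuming a likelihood x_i ~ N(theta, 1) and a Gaussian prior N(mu, sigma^2). The function names, step size, number of Monte Carlo samples and the log-standard-deviation parameterization are illustrative assumptions; the variance reduction tricks mentioned above are not included. The KL gradient is analytic and the expected log-likelihood gradient uses the reparameterization theta = m + s * eps.

```python
import math
import random


def kl_gaussian(m, s, mu, sigma):
    """KL( N(m, s^2) || N(mu, sigma^2) ) for scalar Gaussians."""
    return math.log(sigma / s) + (s ** 2 + (m - mu) ** 2) / (2 * sigma ** 2) - 0.5


def fit_posterior(data, mu, sigma, steps=2000, lr=0.02, n_mc=8, seed=0):
    """Maximize the ELBO for x_i ~ N(theta, 1), theta ~ N(mu, sigma^2).
    Returns the mean and standard deviation of the Gaussian approximation."""
    rng = random.Random(seed)
    m, rho = mu, math.log(sigma)  # initialize the posterior at the prior
    for _ in range(steps):
        s = math.exp(rho)
        g_m, g_rho = 0.0, 0.0
        for _ in range(n_mc):
            eps = rng.gauss(0.0, 1.0)
            theta = m + s * eps                   # reparameterized sample
            dlogp = sum(x - theta for x in data)  # d/dtheta log p(data | theta)
            g_m += dlogp / n_mc                   # chain rule: dtheta/dm = 1
            g_rho += dlogp * eps * s / n_mc       # chain rule: dtheta/drho = s * eps
        # subtract the analytic gradient of the KL term
        g_m -= (m - mu) / sigma ** 2
        g_rho -= (-1.0 / s + s / sigma ** 2) * s
        m += lr * g_m
        rho += lr * g_rho
    return m, math.exp(rho)
```

For this conjugate model the exact posterior is known in closed form, so the result of `fit_posterior` can be checked against it.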

A special case worth mentioning is when we use a delta function for the posterior approximation : we have , which is exactly the inner-update step of iMAML rajeswaran2019meta, MAML finn2017model, and Reptile nichol2018first (if we replace the L2 regularization term by choosing as the initial point for gradient-based optimization: ).
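
The two flavours of point-estimate inner update mentioned here can be sketched on a toy quadratic task loss. This is an illustration under assumptions: the loss 0.5 * (theta - target)^2, the regularization weight `lam`, and the function names are all hypothetical. The first function minimizes the task loss plus an L2 penalty toward the prior mean (the regularized form); the second drops the penalty and relies only on starting from the prior mean and taking a few steps.

```python
def grad_task_loss(theta, target):
    """Gradient of the toy task loss 0.5 * (theta - target)^2."""
    return theta - target


def inner_update_proximal(theta0, target, lam=1.0, lr=0.1, steps=200):
    """Regularized inner update: gradient descent on
    loss(theta) + lam / 2 * (theta - theta0)^2."""
    theta = theta0
    for _ in range(steps):
        g = grad_task_loss(theta, target) + lam * (theta - theta0)
        theta -= lr * g
    return theta


def inner_update_init_only(theta0, target, lr=0.1, steps=5):
    """Unregularized inner update: a few plain gradient steps from theta0."""
    theta = theta0
    for _ in range(steps):
        theta -= lr * grad_task_loss(theta, target)
    return theta
```

For the quadratic toy loss the regularized solution is (target + lam * theta0) / (1 + lam), so both variants stay between the initial point and the task optimum, one through the penalty and the other through early stopping.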

3.2 Meta-Gradient

The essential part of this meta-learning framework is to compute the gradient . We reduce this problem to computing given and . For , this is direct. For , we have two approaches. In the first approach we compute as stated above, then , where can be computed by automatic differentiation (if is computed by gradient-based algorithms). This approach is widely used in previous work such as finn2017model, finn2018probabilistic, yoon2018bayesian, grant2018recasting. The second approach, which we propose, is implemented in the subroutine Meta-Gradient:GEM-BML+ below. We use the property (proof in [Appendix]) that can be expressed as the difference of two terms.

We propose an efficient way to compute through the gradient of the complete log-likelihood. This is guaranteed by the following Gradient-EM Theorem, inspired by the observation in salakhutdinov2003optimization.

Theorem 2.


Under the hierarchical modeling structure, we have . Using Theorem 1 we have . Using VI to compute the approximate posterior parameters, we get the GEM estimator , which can be calculated analytically in the Gaussian case, as we show in the Appendix. This gives us two Meta-Gradient subroutines, GEM-BML and GEM-BML+, for and respectively. We refer to Algorithm 1 with these two subroutines as our algorithms GEM-BML and GEM-BML+.

1 Subroutine Meta-Gradient:GEM-BML(,)
2        Compute posterior VI(, ).
3        Compute posterior VI(, ).
4        return
1 Subroutine Meta-Gradient:GEM-BML+(,)
2        Compute posterior VI(, ).
3        Compute posterior VI(, ).
4        return
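
The following sketch illustrates the idea behind these subroutines in a scalar Gaussian toy model; it is our reading under stated assumptions, not the paper's implementation. It relies on the Gradient-EM (Fisher) identity, which says that the gradient of the log marginal likelihood with respect to the prior parameters equals the posterior expectation of the gradient of the log prior density. For a Gaussian prior N(mu, sigma^2) and Gaussian posterior this expectation is analytic. The "+" variant is written as a difference of two such terms, using log p(val | train) = log p(train, val) - log p(train). The names `posterior`, `gem_gradient`, `gem_plus_gradient` and `log_marginal`, the unit-variance likelihood, and the exact posterior used in place of VI are assumptions.

```python
import math


def posterior(xs, mu, sigma):
    """Exact posterior (mean, std) for x_i ~ N(theta, 1), theta ~ N(mu, sigma^2).
    Stands in for the VI subroutine in this toy."""
    precision = 1.0 / sigma ** 2 + len(xs)
    mean = (mu / sigma ** 2 + sum(xs)) / precision
    return mean, math.sqrt(1.0 / precision)


def gem_gradient(post_mean, post_std, mu, sigma):
    """E_q[ d/d(mu, sigma) log N(theta | mu, sigma^2) ] for q = N(post_mean, post_std^2)."""
    diff = post_mean - mu
    d_mu = diff / sigma ** 2
    d_sigma = -1.0 / sigma + (post_std ** 2 + diff ** 2) / sigma ** 3
    return d_mu, d_sigma


def gem_plus_gradient(train, val, mu, sigma):
    """Gradient of log p(val | train) as a difference of two GEM terms:
    one with the posterior given train + val, one with the posterior given train."""
    g_all = gem_gradient(*posterior(train + val, mu, sigma), mu, sigma)
    g_tr = gem_gradient(*posterior(train, mu, sigma), mu, sigma)
    return g_all[0] - g_tr[0], g_all[1] - g_tr[1]


def log_marginal(xs, mu, sigma):
    """Closed-form log p(xs | mu, sigma), used only to check the gradients."""
    n = len(xs)
    r = [x - mu for x in xs]
    c = sigma ** 2 / (1.0 + n * sigma ** 2)
    quad = sum(v * v for v in r) - c * sum(r) ** 2
    return -0.5 * quad - 0.5 * math.log(1.0 + n * sigma ** 2) - 0.5 * n * math.log(2 * math.pi)
```

Both gradients only need posterior parameters as inputs, so how those parameters were obtained (which optimizer, how many steps) does not enter the meta-update.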

For comparison, one of the most widely used methods to optimize in Bayesian meta-learning is optimizing the ELBO ravi2018amortized (see the Appendix for other existing methods and a comparative analysis). Here we show that it is actually another way to estimate . According to equation (2), when the VI approximation error is small enough, we have . So the gradient can be computed by . The first partial gradient term can be computed by the same method as in Section 3.1, and the second one can be calculated analytically in the Gaussian case.

In fact, Gradient-EM can also be viewed as a coordinate descent algorithm that optimizes the ELBO, as a variant of EM, as we show in the Appendix. Compared to the ELBO gradient, Gradient-EM avoids the backpropagation computation of , which gives it a series of advantages that we specify in Section 4. Both the GEM and ELBO gradients have an estimation error that arises from the discrepancy between the posterior estimated by VI and the true posterior. We show empirically in the Appendix that GEM has a consistently lower estimation error than the ELBO gradient. We also show in the Appendix that our method has a theoretical bound on the estimation error in terms of the VI discrepancy , where is a bounded constant.

4 Analysis

We compare Gradient-EM (our method) with the ELBO gradient over two loss functions , summarized in Table 1. It turns out that each element of this matrix corresponds to a related work or to our method. Here we only show how MAML and Reptile fit into this Bayesian framework; further details are left to the [Appendix]. To see this, consider using fixed variance parameters for both prior and posterior and let so that the posterior becomes a delta distribution . We can compute by gradient descent from . MAML uses as the meta-loss function. Under a delta-distribution posterior, . Then can be computed directly through back-propagation in neural networks. On the other hand, Reptile uses as the meta-loss function. Using the GEM gradient, , which is the Reptile gradient. In this sense, Gradient-EM is the Bayesian version of Reptile. Also notice that if we let and so that the prior becomes a delta, then . This corresponds to "pre-training", which simply trains a model to fit the data of all tasks combined. Previous work has shown empirically that this method does not perform well finn2017model. Here we can provide a theoretical explanation: we show in the Appendix that is a biased estimator when , while it is unbiased in Reptile, where .

ELBO gradient | Amortized BML ravi2018amortized | related to PMAML finn2018probabilistic
Gradient-EM | GEM-BML (our method); reduces to Reptile nichol2018first in the delta case | KL-Chaser Loss (related to l2-Chaser Loss, BMAML yoon2018bayesian)
Table 1: Matrix of related works
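
The Reptile correspondence described in this section can be checked on a toy example. The sketch below is illustrative and assumes a scalar parameter, a quadratic task loss, a point-mass posterior located at the inner-update result, and a fixed prior standard deviation; all function names are hypothetical. With these assumptions the Gradient-EM gradient with respect to the prior mean is proportional to the displacement produced by the inner update, which is the direction Reptile moves in.

```python
def inner_sgd(theta0, target, lr=0.1, steps=5):
    """A few gradient steps on the toy loss 0.5 * (theta - target)^2."""
    theta = theta0
    for _ in range(steps):
        theta -= lr * (theta - target)
    return theta


def gem_mean_gradient(theta_hat, prior_mean, prior_std):
    """Gradient-EM gradient w.r.t. the prior mean when the posterior is a
    point mass at theta_hat and the prior variance is held fixed."""
    return (theta_hat - prior_mean) / prior_std ** 2


def reptile_step(prior_mean, target, meta_lr=0.5):
    """Reptile meta-update: move the initialization toward the adapted parameter."""
    theta_hat = inner_sgd(prior_mean, target)
    return prior_mean + meta_lr * (theta_hat - prior_mean)
```

With `prior_std = 1`, a gradient-ascent step of size `meta_lr` along `gem_mean_gradient` coincides with `reptile_step`; other values of `prior_std` only rescale the step.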

Observe that all methods in the above matrix require computing the posterior parameters first and then using them to compute the sampled meta-loss gradient . Following the convention of finn2017model, we define the step of computing as the inner-update and the step of computing as the meta-update. Notice that both the column and the ELBO-gradient row involve the computation of . This means that these three methods (highlighted in colour) require backpropagation through the inner optimization process, which imposes a number of burdens and limitations, while our method avoids this computation and thus gains the advantages mentioned in the Introduction. Also notice that, assuming independence between neural network layers, the meta-update of our algorithm (Line 4 of Subroutine GEM-BML(+)) can be computed for different neural network layers in parallel, which may greatly reduce the computation time in deep neural networks. A detailed analysis of these advantages is given in the Appendix.

5 Experiment

5.1 Regression

The purpose of this experiment is to test the fast adaptation ability of our methods and their robustness to meta-level uncertainty.

We compare our models GEM-BML and GEM-BML+ with MAML finn2017model, Reptile nichol2018first and Amortized BML ravi2018amortized on the same sinusoidal function regression problem. We first apply the default setting in finn2017model and then a more challenging setting with more uncertainty, as proposed in yoon2018bayesian, to demonstrate robustness to meta-level uncertainty. Data of each task is generated from with amplitude A, frequency w, and phase b as task parameters and observation noise . Task parameters are sampled from uniform distributions and observation noise follows . ranges from . For each task, observations ( pairs) are given. The underlying network architecture (2 hidden layers of size 40 with ReLU activation) is the same as in finn2017model to make a fair comparison. Since our model is probabilistic, we use the expectations of the model weights in the inference phase for evaluation.
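
A task sampler for this kind of benchmark might look like the sketch below. The functional form y = A sin(w x + b) + noise and all numeric ranges (amplitude, frequency, phase, input range, noise scale, number of observations) are assumptions chosen for illustration, since the exact values are not recoverable from the text above; the function names are hypothetical as well.

```python
import math
import random


def sample_task(rng, amp=(0.1, 5.0), freq=(0.5, 2.0), phase=(0.0, 2 * math.pi)):
    """Draw task parameters (A, w, b) from uniform ranges (assumed values)."""
    return rng.uniform(*amp), rng.uniform(*freq), rng.uniform(*phase)


def sample_observations(rng, task, k=10, x_range=(-5.0, 5.0), noise_scale=0.01):
    """Draw k (x, y) pairs with y = A * sin(w * x + b) + Gaussian noise.
    The noise standard deviation is noise_scale * A (an assumed form)."""
    A, w, b = task
    xs = [rng.uniform(*x_range) for _ in range(k)]
    ys = [A * math.sin(w * x + b) + rng.gauss(0.0, noise_scale * A) for x in xs]
    return xs, ys
```

Fixing the frequency range to a single value recovers a simpler setting with only amplitude and phase varying across tasks.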

In Figure 2 (a), we plot the mean squared error (MSE) on test tasks during the meta-test process under both settings. Under the default setting, our methods show fast-adaptation ability similar to previous methods. The results for the challenging setting show that the Bayesian methods GEM-BML(+) and Amortized BML can still extract information in a high-uncertainty environment, while the non-Bayesian models MAML and Reptile fail to learn. We also observe that our model provides a stable meta-train learning curve and continues to improve with more gradient steps without overfitting. This demonstrates the robustness of Bayesian methods, which results from their probabilistic nature and their ability to control overfitting.



Figure 2: (left two) Sinusoidal regression results: Meta-test error of default and challenging setting after 40000 meta-train iterations. (right two) Computation and memory trade-offs with 4 layer CNN on 1-shot, 5-class miniImageNet task.

Omniglot | 1-shot, 5-class | 5-shot, 5-class | 1-shot, 20-class | 5-shot, 20-class
MAML | 98.7 ± 0.4 | 99.9 ± 0.1 | 95.8 ± 0.3 | 98.9 ± 0.2
first-order MAML | 98.3 ± 0.5 | 99.2 ± 0.2 | 89.4 ± 0.5 | 97.9 ± 0.1
Reptile | 97.68 ± 0.04 | 99.48 ± 0.06 | 89.43 ± 0.14 | 97.12 ± 0.32
iMAML | 99.50 ± 0.26 | 99.74 ± 0.11 | 96.18 ± 0.36 | 99.14 ± 0.1
GEM-BML+ (Ours) | 99.23 ± 0.42 | 99.64 ± 0.08 | 96.24 ± 0.35 | 98.94 ± 0.25
Table 2: Few-shot classification on the Omniglot dataset. The ± shows 95% confidence intervals over different testing tasks. All results used for comparison are taken from the original literature.


Figure 3: Reinforcement Learning
miniImageNet | 1-shot, 5-class
MAML finn2017model | 48.70 ± 1.84
first-order MAML finn2017model | 48.07 ± 1.75
Reptile nichol2018first | 49.97 ± 0.32
iMAML rajeswaran2019meta | 49.30 ± 1.88
Amortized BML ravi2018amortized | 45.0 ± 0.60
GEM-BML+ (Ours) | 50.03 ± 1.63

Predictive uncertainty | ECE | MCE
MAML | 0.0471 | 0.1104
Amortized BML | 0.0124 | 0.0257
GEM-BML+ (Ours) | 0.0102 | 0.0197
Table 3: Accuracy and predictive uncertainty measurement of few-shot classification on the MiniImagenet dataset. Small ECE and MCE indicate that a model is better calibrated.

5.2 Classification

This experiment aims to answer the following questions: (1) Does our model save computation time and memory by avoiding backpropagation in the meta-update, as we claimed? (2) Can our methods be scaled to few-shot image classification benchmarks and achieve good accuracy and predictive uncertainty?

To study (1), we turn to the Mini-ImageNet dataset ravi2016optimization in the 1-shot, 5-class setting. We compare GEM-BML+ (GEM-BML is even less expensive) with MAML and its first-order variants (foMAML, Reptile), iMAML, Amortized BML and BMAML in Fig 2(b). Like other first-order meta-learning algorithms and iMAML, which decouple the inner-update and meta-update, the memory usage of GEM-BML(+) is independent of the number of inner-update gradient steps, since intermediate inner-update computations other than the final result need not be stored. In contrast, MAML-like algorithms (MAML, Amortized BML) need memory that grows linearly with the number of inner-update gradient steps. Compute time behaves similarly: MAML-like algorithms require expensive backpropagation through the inner-update optimization in the meta-update, so their compute cost grows at a faster rate than that of GEM-BML(+), foMAML, Reptile and iMAML (which has a relatively high base compute cost because of Hessian computation).

To study (2), we applied our method to N-class image classification on the Omniglot and MiniImagenet datasets, which are popular few-shot learning benchmarks (vinyals2016matching; santoro2016meta; ravi2016optimization). Notice that the purpose of this experiment is not to compete with the state of the art on these benchmarks but to provide an apples-to-apples comparison with prior works within our extended Empirical Bayes framework. For a fair comparison, we therefore use the same backbone convolutional architecture finn2017model as these prior works. Note, however, that this backbone architecture can be replaced with other ones, leading to better results for all algorithms chen2019closer; kim2018auto. We leave improving our method with better backbone architectures, to challenge the state of the art on these benchmarks, to future work. The inner-update is computed using Adam to demonstrate the flexibility of our methods in choosing the inner-update optimizer. The results in Table 2 show that our methods perform as well as the best prior methods within our extended Empirical Bayes framework.

Predictive uncertainty is the probability associated with the predicted class label, which indicates how likely the prediction is to be correct. To measure the predictive uncertainty of the models, we apply two quantitative metrics, ECE and MCE (naeini2015obtaining; guo2017calibration), to the MiniImagenet dataset. Smaller ECE and MCE indicate a better-calibrated model. A perfectly calibrated model achieves 0 for both metrics. The results of ECE and MCE for our models and previous works are shown in Table 3. We can see that our model is slightly better calibrated than the state-of-the-art Bayesian meta-learning model Amortized BML and clearly outperforms non-Bayesian models. This shows that our model can learn a good prior and make good probability predictions, which is an advantage of Bayesian models.
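
For reference, a common way to compute these two metrics is sketched below: predictions are grouped into equal-width confidence bins, ECE is the sample-weighted average of the gap between accuracy and mean confidence per bin, and MCE is the largest such gap. The function name, the number of bins and the handling of bin edges are assumptions of this sketch and may differ from the evaluation code behind Table 3.

```python
def calibration_errors(confidences, correct, n_bins=10):
    """Return (ECE, MCE) using n_bins equal-width confidence bins.

    confidences: predicted probability of the chosen class, each in [0, 1]
    correct:     1 if the prediction was right, 0 otherwise
    """
    n = len(confidences)
    bins = [[] for _ in range(n_bins)]
    for conf, hit in zip(confidences, correct):
        idx = min(int(conf * n_bins), n_bins - 1)  # confidence 1.0 goes in the last bin
        bins[idx].append((conf, hit))

    ece, mce = 0.0, 0.0
    for members in bins:
        if not members:
            continue  # empty bins contribute nothing
        avg_conf = sum(c for c, _ in members) / len(members)
        accuracy = sum(h for _, h in members) / len(members)
        gap = abs(accuracy - avg_conf)
        ece += len(members) / n * gap
        mce = max(mce, gap)
    return ece, mce
```

For example, two correct predictions at confidence 0.95 and two wrong ones at confidence 0.55 give a gap of 0.05 in the top bin and 0.55 in the other, so ECE is 0.30 and MCE is 0.55.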

5.3 Reinforcement learning

We test and compare the models on the same 2D Navigation and MuJoCo continuous control tasks as used in finn2017model. See the Appendix for detailed descriptions of the experiment settings and hyper-parameters.

For a fair comparison, we use the same policy network architecture as finn2017model, with two hidden layers of 100 ReLU units each. At meta-train, we collect samples from rollouts of the current policy and another samples from rollouts after one policy gradient update, as in finn2017model. At meta-test, we compare adaptation to a new task with up to 3 gradient updates, each with 40 samples. We compare to two baseline models: MAML and Reptile.

MAML uses TRPO in the meta-update to boost performance, while our meta-update is data-free, as specified in the sections above. For inner-updates, thanks to our model's flexibility in choosing the inner-update optimizer, we can use either vanilla policy gradient (REINFORCE) (Williams, 1992) or a specially designed TRPO proposed by finn2017model. We find that the TRPO inner-update performs better in 2D navigation, while vanilla policy gradient tends to be better in MuJoCo continuous control tasks. We hypothesize that the reason is that in complex task settings the task distribution variance tends to be higher ( is larger in Figure 1 (a)), while TRPO limits the step size of each inner-update, which makes the task parameters hard to reach within a few gradient steps.

As shown in Fig 3, GEM-BML+ outperforms MAML, while Reptile and GEM-BML perform worse. This shows variant is necessary in RL, which has high in-task variance and is easily overfitted. Previous work nichol2018first shows that it is hard to adapt algorithms with the advantage of a data-free meta-update (Reptile-like algorithms) to RL. With our variant, however, we can adapt to RL while preserving this advantage. Our results show that for RL, the key to adaptation is variant.

In the Appendix, we also provide results on a multi-armed bandit task, where we observe a similar superiority of our methods, which demonstrates the advantage of Bayesian methods in exploration.

6 Related Works

Hierarchical Bayes (HB) and Empirical Bayes (EB) have been well studied in the past heskes1998solving as ways to exploit statistical connections between related tasks. Since then, deep neural networks (DNNs) have attracted enormous attention, and efforts to measure the uncertainty of DNNs have followed, in which Bayesian and sampling methods are widely applied blundell2015weight.

The research trend in multi-task learning and transfer learning has since shifted to the fine-tuning framework for DNNs. Model-Agnostic Meta-Learning (MAML) finn2017model emerged from this motivation, aiming to find good initial parameters that can be quickly adapted to new tasks in a few gradient steps. Recently, Bayesian models have made a big comeback because of their probabilistic nature, which supports uncertainty measurement and automatic prevention of overfitting. wilson2007multi applied hierarchical Bayesian models to multi-task reinforcement learning. grant2018recasting related MAML to hierarchical Bayesian models and proposed a Laplace approximation method to capture isotropic Gaussian uncertainty of the model parameters. yoon2018bayesian used Stein Variational Gradient Descent (SVGD) to obtain posterior samples and proposed a Chaser Loss to prevent meta-level overfitting. finn2018probabilistic also proposed a gradient-based method to obtain a fixed measure of prior and posterior uncertainty. ravi2018amortized proposed a MAML-like variational inference method for amortized Bayesian meta-learning. None of the methods above makes the inner-update and meta-update separable, which largely limits the flexibility of the inner-update optimization process. rajeswaran2019meta propose an implicit gradient method for MAML that makes the inner-update and meta-update separable, at the cost of computing second-order derivatives and solving a quadratic optimization problem in each inner-update step.

7 Conclusion

Inspired by the Gradient-EM algorithm, we have proposed the GEM-BML(+) algorithm for Bayesian meta-learning. Our method is based on the theoretical insight of the Gradient-EM Theorem and the Bayesian formulation of multi-task meta-learning. This method avoids backpropagation in the meta-update and decouples the meta-update from the inner-update. We have tested our method on sinusoidal regression, few-shot image classification and reinforcement learning to demonstrate its advantages. For future work, we plan to apply our method to state-of-the-art image classification backbones and to extend our work to nonparametric Gaussian approximation to handle multimodal and dynamic task-distribution situations.