ST-MAML: A Stochastic-Task based Method for Task-Heterogeneous Meta-Learning

09/27/2021
by   Zhe Wang, et al.
University of Virginia

Optimization-based meta-learning typically assumes tasks are sampled from a single distribution, an assumption that oversimplifies and limits the diversity of tasks that meta-learning can model. Handling tasks from multiple different distributions is challenging for meta-learning due to a so-called task ambiguity issue. This paper proposes a novel method, ST-MAML, that empowers model-agnostic meta-learning (MAML) to learn from multiple task distributions. ST-MAML encodes tasks using a stochastic neural network module that summarizes every task with a stochastic representation. The proposed Stochastic Task (ST) strategy allows a meta-model to be tailored to the current task and enables us to learn a distribution of solutions for an ambiguous task. ST-MAML also propagates the task representation to revise the encoding of input variables. Empirically, we demonstrate that ST-MAML matches or outperforms the state-of-the-art on two few-shot image classification tasks, one curve regression benchmark, one image completion problem, and a real-world temperature prediction application. To the best of the authors' knowledge, this is the first time an optimization-based meta-learning method has been applied to a large-scale real-world task.


1 Introduction

Meta-learning aims to train a model on multiple machine learning tasks so that it can adapt to a new task with only a few training samples. Optimization-based meta-learning methods such as model-agnostic meta-learning (MAML) pursue this goal by involving the optimization process. For example, MAML trains a global initialization of model parameters that is close to the optimal parameter values of every task Finn, Abbeel, and Levine (2017). Recent methods expand MAML’s “global initialization” to a notion of “globally shared knowledge,” including not only initialization Finn, Abbeel, and Levine (2017); Li et al. (2017); Rajeswaran et al. (2019) but also update rules Andrychowicz et al. (2016); Ravi and Larochelle (2017). The globally shared knowledge is explicitly trained and allows these methods to generalize well to new tasks with a small number of training samples.

Most optimization-based meta-learning algorithms assume all tasks are independently and identically sampled from a single distribution Andrychowicz et al. (2016); Finn, Abbeel, and Levine (2017); Li et al. (2017); Ravi and Larochelle (2017); Rusu et al. (2018). This setup is known as task homogeneity. We call meta-learning’s target task distribution the “meta-distribution”. Real-world tasks, however, may come from multiple meta-distributions. For instance, autonomous driving agents need to handle multiple learning environments, including different lighting, various weather situations, and a diverse set of road shapes. This more challenging setup, which we call task heterogeneity, poses technical challenges to strategies like MAML Vuorio et al. (2019).

For the task heterogeneity setup, a naive and widely accepted meta-learning solution first learns a globally shared initialization across all meta-distributions and then tailors the model parameters to the current task Vuorio et al. (2019); Yao et al. (2020, 2019); Lee and Choi (2018); Oreshkin, Rodriguez, and Lacoste (2018). The tailoring step relies on task-specific information or, ideally, the identity of the task. It therefore requires the meta-learner to infer the potential identity of a new task from a limited number of annotated samples Finn, Xu, and Levine (2018). This requirement raises severe uncertainty issues, a challenge known as “task ambiguity.” Figure 1 provides a concrete example of task ambiguity, which stems not only from the limited annotated data but also from the multiple distributions a task may come from. Surprisingly, recent optimization-based meta-learning literature pays little attention to the task ambiguity challenge Vuorio et al. (2019); Yao et al. (2020, 2019); Lee and Choi (2018).

Figure 1: Two critical challenges in meta-learning. (a, b): The figures show the difference between task homogeneity and task heterogeneity in meta-learning. The solid line with arrow represents the uniformly random sampling from meta distributions (inner circle). (c, d): The figures demonstrate the task ambiguity in meta-learning. In heterogeneous setup, the task ambiguity is more critical due to the distributional uncertainty. The red dots represent the available training data, the dashed and solid curves are potential explanations of the data (better read in color).

This paper proposes a novel meta-learning method, ST-MAML, for the task heterogeneity challenge and centers its design on solving the task ambiguity issue. Our approach extends MAML by modeling tasks as a stochastic variable that we call the stochastic task. The stochastic task allows us to learn a distribution of models that captures the uncertainty of an ambiguous new task. We use variational inference as the solver, and the whole learning process does not require knowing the cardinality of the meta-distributions. We apply ST-MAML to multiple applications, including image completion, few-shot image classification, and temporal forecasting meta-learning problems. To the best of the authors’ knowledge, this is the first time optimization-based meta-learning has been applied to a large-scale real-life task. Our empirical results demonstrate that ST-MAML outperforms the MAML baselines on that task.

2 Methods

Figure 2: Probabilistic model overview of ST-MAML.

Algorithm 1 ST-MAML Meta-Training Procedure.
1: Input: meta-distributions and hyper-parameters.
2: Randomly initialize the model parameters, stochastic task module parameters, tailoring module parameters, and input encoding parameters.
3: while not DONE do
4:    Sample batches of tasks from the meta-distributions.
5:    for every task do
6:       Infer the posterior distribution of the stochastic task variable and draw a sample. [eq.(8) and eq.(10)]
7:       Tailor the global initialization with the sample to get the task-specific initialization. [eq.(11)]
8:       Revise the encoding of the input variable by augmenting the raw input. [eq.(12)]
9:       Evaluate the inner loss on the training set. [eq.(16)]
10:      Compute the adapted parameter and augmented feature with gradient descent. [eq.(17)]
11:   end for
12:   Update all learnable parameters using the objective in eq.(15).
13: end while

2.1 Preliminaries on Meta Learning

We describe a supervised learning task in meta-learning as

(1)

Here the loss function takes a model and a dataset as input and measures the quality of the learner under its current parameter weights. Every task includes an annotated training set and a test set. During meta-training, the test set is fully observed, but during meta-testing only its inputs are available. Both sets are sampled from the product of the input space and the output space.

The goal of meta-learning is that, on every task, the learner performs well on the test set after fine-tuning on this task’s training set. MAML Finn, Abbeel, and Levine (2017) achieves this goal by learning a globally shared weight initialization that is close to the optimal weight parameter of every task. We can write its training objective for obtaining the best initialization as:

(2)

MAML samples a set of tasks from the meta-distribution and initializes each task’s weights from the global knowledge (to be learnt), i.e., each task starts from the shared initialization. On each task, the learner performs gradient descent on its training set to reach task-specific fine-tuned parameters. The test set of the task is used to evaluate the fine-tuned parameters, and this evaluation serves as the objective for learning the best global knowledge.
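The two-level procedure described above can be illustrated with a minimal sketch. It assumes a deliberately tiny setting that is not from the paper: a one-parameter linear model, a single inner gradient step, and a hypothetical task family `y = a * x`. All function names and hyper-parameter values are illustrative. For this model the MAML meta-gradient can be written in closed form through the chain rule.

```python
import random


def mse_grad(w, data):
    """Gradient of mean((w*x - y)^2) with respect to the scalar weight w."""
    return sum(2.0 * (w * x - y) * x for x, y in data) / len(data)


def mse_curvature(data):
    """Second derivative of the same loss; constant for a linear model."""
    return sum(2.0 * x * x for x, _ in data) / len(data)


def maml_meta_gradient(w0, train, test, inner_lr):
    """Exact MAML gradient for one task with a single inner step."""
    w_task = w0 - inner_lr * mse_grad(w0, train)           # inner loop on the training set
    d_wtask_d_w0 = 1.0 - inner_lr * mse_curvature(train)   # derivative of the inner step
    return mse_grad(w_task, test) * d_wtask_d_w0           # chain rule through the test loss


def sample_task(rng, n=5):
    """Hypothetical task family (an assumption): y = a * x with slope a near 2."""
    a = rng.gauss(2.0, 0.5)

    def draw():
        return [(x, a * x) for x in (rng.uniform(-1, 1) for _ in range(n))]

    return draw(), draw()


def meta_train(steps=500, batch=8, inner_lr=0.1, outer_lr=0.05, seed=0):
    """Outer loop: average the per-task meta-gradients and update the shared init."""
    rng = random.Random(seed)
    w0 = 0.0
    for _ in range(steps):
        tasks = [sample_task(rng) for _ in range(batch)]
        g = sum(maml_meta_gradient(w0, tr, te, inner_lr) for tr, te in tasks) / batch
        w0 -= outer_lr * g
    return w0
```

In this toy setting the learned initialization settles near the average slope of the task family, which is the sense in which the global initialization is "close to" every task's optimum.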

The above objective (in Eq. (2)) can be equivalently framed as maximizing the likelihood :

(3)
(4)

where the distribution over the fine-tuned parameters is a Dirac distribution derived by minimizing the negative log-likelihood (NLL) on the training set with gradient descent.

2.2 Previous Heterogeneous Meta Learning

Task-homogeneous meta-learning assumes that there exists one meta-distribution and that all tasks are independently and identically distributed (i.i.d.) samples from it. In contrast, in a task-heterogeneous setup, there exist multiple meta-distributions. Figure 1 (a, b) compares the two meta-learning setups.

We can naively use MAML and assign all tasks with the same global initialization (though they come from different distributions). Figure 1(c, d) show that the ”task ambiguity” issue is more critical in task-heterogeneous setup and will hinder the generalization from MAML initialization since multiple very different task distributions exist.

A handful of previous works learn a customized initialization that was tailored from global initialization, in order to tackle the task heterogeneity challenge. MMAML Vuorio et al. (2019) learns a deterministic task embedding with an RNN module. HSML Yao et al. (2019) manually designs a task clustering algorithm to assign tasks to different clusters, then customizes the global initialization to each cluster. ARML Yao et al. (2020) models global knowledge and task-specific knowledge as graphs; the interaction between tasks is modeled by message passing.

Surprisingly, none of the recent works consider the task ambiguity issue. Most frameworks are still based on the assumption that only one distribution exists to explain a task’s observed training set (e.g., a new task should be assigned to only one cluster in HSML). The potential identities of a task can be highly uncertain when annotated data is limited. Figure 1(d) shows that the explanation of the observations can vary in the task-heterogeneous setup, so we should not expect to obtain a unique predictor.

2.3 Stochastic Variable to Encode Task

When facing the task-heterogeneous setup, we hypothesize that a meta-learner that can encode the patterns of potential tasks will alleviate the task ambiguity issue to some degree. These patterns could describe valuable information about tasks, such as the more probable shapes of curves for a regression meta-application. Moreover, we propose to enable task encoding with uncertainty estimates. Learning a task representation from limited annotated data is challenging, and such uncertainty measures can help inform the downstream meta-adaptation to new tasks (see Figure 1(d)).

This hypothesis motivates us to describe a task with a stochastic variable and to model its distribution conditioned on observations. By adding this latent variable, we can rewrite the per-task likelihood in Eq. (3) as:

(5)

We assume in the second term from above, only conditions on . Figure 2 shows our design.

In Section (2.5), we show that because the likelihood defined above is intractable, we choose to maximize its evidence lower bound (ELBO) instead. Optimizing this variational objective requires the prior and the posterior of the stochastic task variable. We model the prior as a Gaussian distribution whose mean and variance are the outputs of a two-layer multi-layer perceptron (MLP) module applied to an input vector:

(6)

Here the input vector summarizes the encoding of a task. We propose a neural network module to learn it from the sample observations. The training observations of a task consist of unordered annotated data pairs. Permutation invariance is a desirable property for functions acting on sets. Deep Sets Zaheer et al. (2017) proves that a function acting on sets is permutation invariant if and only if it can be decomposed as $\rho\left(\sum_{x \in X} \phi(x)\right)$ for a suitable choice of transformations $\rho$ and $\phi$. We follow this design and encode a task by passing every observed pair in its training set through a neural network layer:

(7)
(8)

Eq. (8) uses the average as the aggregation operator to obtain the task embedding because it removes the inductive bias due to different training set sizes. In Eq. (7), the per-pair encoder is implemented as an MLP module with learnable parameters.
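A minimal sketch of this set-based encoding follows. The per-pair map `phi` is a fixed, hand-written stand-in (an assumption) for the learned MLP, and the function names are hypothetical. The point is only that averaging per-pair features makes the embedding independent of both the order and the number of observations.

```python
import math


def phi(x, y):
    """Hypothetical per-pair feature map standing in for the learned MLP."""
    return [math.tanh(x + y), math.tanh(x - y), math.tanh(x * y)]


def encode_task(pairs):
    """Average the per-pair features to get a fixed-size task embedding."""
    feats = [phi(x, y) for x, y in pairs]
    return [sum(col) / len(feats) for col in zip(*feats)]
```

Shuffling `pairs` leaves the embedding unchanged up to floating-point rounding, and repeating every pair twice gives the same embedding as the original set, which a sum aggregator would not.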

We then approximate the intractable posterior distribution of the stochastic task variable, conditioned on the whole task (see Section (A.2)):

(9)
(10)

where the encoder and the mean and variance networks are the same MLP modules as in Eq. (6).
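As an illustration of how a Gaussian task variable can be produced from a task embedding, here is a small sketch. Two assumptions are made for brevity and are not taken from the paper: single linear heads replace the two-layer MLP, and the network outputs a log-variance. The reparameterized draw (mean plus scaled standard normal noise) is a common choice in amortized variational inference. All names are hypothetical.

```python
import math
import random


def gaussian_params(embedding, w_mu, w_logvar):
    """Hypothetical linear heads mapping a task embedding to mean and log-variance."""
    mu = [sum(w * e for w, e in zip(row, embedding)) for row in w_mu]
    logvar = [sum(w * e for w, e in zip(row, embedding)) for row in w_logvar]
    return mu, logvar


def sample_task_variable(mu, logvar, rng):
    """Reparameterized draw: mu + sigma * eps with eps ~ N(0, 1)."""
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, logvar)]
```

Using the same heads for the prior (training pairs only) and the approximate posterior (training and test pairs) mirrors the parameter sharing described above.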

2.4 ST-MAML: Customizing Knowledge with the Task Representation

With the summary task representation in hand, we use it to revise MAML into ST-MAML for the heterogeneous meta-learning setup. We propose to tailor the global initialization into a task-specific initialization for each task.

There exist many potential ways to use the task representation to tailor the global initialization into a task-specific initialization. We choose the following design. We assume that our target learning machine, like most neural network models, is composed of a base learner and a task learner. The base learner and the task learner each have their own parameters (the task learner is, for instance, the last linear layer before the softmax in the classification case), so the full parameter set can be rewritten as the combination of the two. We propose to customize only one of these two parameter groups with the task representation:

(11)

Here the task representation is sampled from the posterior distribution during meta-training and from the prior during meta-testing. The tailoring uses the sigmoid function and element-wise multiplication, and the remaining symbols are learnable parameters.
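The exact form of Eq. (11) is not visible in this text, so the following is only a plausible sketch consistent with the ingredients named above: a sigmoid gate computed from the task sample and applied element-wise to the global weights. The affine map `W`, `b` and the function names are assumptions for illustration.

```python
import math


def sigmoid(v):
    """Logistic function mapping any real number into (0, 1)."""
    return 1.0 / (1.0 + math.exp(-v))


def tailor(theta_global, s, W, b):
    """Gate each global weight by sigmoid(W s + b), element-wise (assumed form)."""
    gates = [sigmoid(sum(w * si for w, si in zip(row, s)) + bi) for row, bi in zip(W, b)]
    return [t * g for t, g in zip(theta_global, gates)]
```

Because a new task sample gives a new gate, drawing several samples yields several task-specific initializations for the same ambiguous task.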

Moreover, we design additional customized knowledge for each task. The basic intuition is that the final prediction of a meta-learner depends on both model parameters and input representations. To increase the capacity of the task-specific knowledge, we further propagate the task representation into an augmented feature representation. We concatenate this augmented feature with a sample’s input representation and feed the combined vector to our learning machine as its new input.

(12)

As in Eq. (11), the task representation is sampled from its distribution, and the mapping has learnable parameters.

Now when facing a new task, a meta-model first generates the task-specific knowledge, which includes both the augmented feature and the task-specific parameters. We denote the combined knowledge set for the task as:

(13)

This is the meta-knowledge we need to learn in ST-MAML. We distinguish its initial values from its fine-tuned values.

Aiming to learn the meta knowledge defined in Eq. (13), now we can write our objective (task likelihood) in Eq. (5) into the following factorization:

(14)

This follows the Bayesian graph provided in Figure 2.

Design Choices:

Many other probabilistic designs are possible besides Figure 2. For instance, we could model every variable in the figure as a stochastic distribution and build a complicated hybrid framework. However, this would lead to excessive stochasticity and increase the risk of underfitting, especially in a limited-data situation. Instead, we choose to model both the task-specific initialization and the augmented feature as deterministic (see Eq. (11) and Eq. (12)), which allows us to employ an amortized variational inference technique Ravi and Beatson (2019).

Our design differs from recent probabilistic extensions of MAML Finn, Xu, and Levine (2018); Yoon et al. (2018), which conduct inference on model parameters (either the initial or the fine-tuned values). ST-MAML shifts the burden of variational inference to the task representation, whose dimension is several orders of magnitude smaller than the size of the model parameters.

2.5 ST-MAML: Update Rules

Figure 3: Iterative optimization process. In the inner loop, starting from the task-specific parameter initialization and the augmented features, their fine-tuned values are obtained by performing gradient descent on the training set for several iterations.

Variational Objective:

To optimize the intractable likelihood as defined in Eq. (14), we choose to maximize its evidence lower bound (a.k.a ELBO) instead:

(15)

During meta-training, we sample tasks and optimize the empirical average of the per-task objective.
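For Gaussian prior and posterior, the regularization term of an ELBO has a closed form. The sketch below assumes diagonal covariances parameterized by log-variances and a per-task loss of the usual "test-set NLL plus KL" shape; the `kl_weight` argument and all names are illustrative assumptions, not the paper's notation.

```python
import math


def kl_diag_gaussians(mu_q, logvar_q, mu_p, logvar_p):
    """KL(q || p) for diagonal Gaussians, summed over dimensions."""
    total = 0.0
    for mq, lq, mp, lp in zip(mu_q, logvar_q, mu_p, logvar_p):
        total += 0.5 * (lp - lq + (math.exp(lq) + (mq - mp) ** 2) / math.exp(lp) - 1.0)
    return total


def negative_elbo(test_nll, mu_q, logvar_q, mu_p, logvar_p, kl_weight=1.0):
    """Per-task loss: test-set NLL plus the weighted KL from posterior to prior."""
    return test_nll + kl_weight * kl_diag_gaussians(mu_q, logvar_q, mu_p, logvar_p)
```

Averaging `negative_elbo` over a batch of sampled tasks gives the empirical objective that the outer loop would minimize under these assumptions.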

Update Rules:

As in MAML, the optimization of ST-MAML contains two loops: the inner loop and the outer loop. Figure 3 shows the iterative optimization process. In the inner loop, for the training data, we concatenate each input with the augmented feature to get the augmented input vector. We feed it into the learning machine, whose parameters are the task-specific parameters, to calculate the inner loss:

(16)

The inner loss is then used to update both the task-specific parameters and the augmented feature:

(17)

Figure 3 shows that we can optimize the inner loss for several iterations to achieve a closer approximation of the optimal values in Eq. (16). In the outer loop, we maximize the approximated ELBO in Eq. (15) using a batch of tasks. The amortized variational technique allows us to sample the knowledge set by first sampling the task representation from its distribution and then applying the deterministic transformations in Eq. (11) and Eq. (12).
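The inner loop described above, in which both the task-specific parameters and the augmented feature are refined by gradient descent, can be sketched for a toy case. The assumptions here are a scalar input `x`, a scalar augmented feature `c`, a linear model on the concatenated input `[x, c]`, and a squared-error loss; none of these choices come from the paper.

```python
def inner_loss(w, c, data):
    """MSE of a linear model on the augmented input [x, c]."""
    return sum((w[0] * x + w[1] * c - y) ** 2 for x, y in data) / len(data)


def inner_step(w, c, data, lr):
    """One gradient step on both the weights w and the augmented feature c."""
    n = len(data)
    res = [w[0] * x + w[1] * c - y for x, y in data]
    g_w0 = sum(2.0 * r * x for r, (x, _) in zip(res, data)) / n
    g_w1 = sum(2.0 * r * c for r in res) / n
    g_c = sum(2.0 * r * w[1] for r in res) / n
    return [w[0] - lr * g_w0, w[1] - lr * g_w1], c - lr * g_c


def adapt(w, c, data, lr=0.05, steps=5):
    """Run a few inner steps starting from the task-specific initialization."""
    for _ in range(steps):
        w, c = inner_step(w, c, data, lr)
    return w, c
```

Treating the augmented feature as a trainable quantity in the inner loop means that the input representation, and not only the weights, adapts to the task.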

Algorithm of ST-MAML:

We describe the procedure of ST-MAML as pseudocode in Algorithm 1. Note that the parameters of all neural modules are updated in the outer loop.

Theoretical Analysis of ST-MAML:

We also provide a second interpretation of our objective from the information bottleneck perspective and prove that both lead to exactly the same target. See Section (A.3) for detailed proofs.

2.6 Connecting to Related Work

Optimization-based meta-learning methods facilitate the model’s adaptation to new tasks through global knowledge learned by the optimization process. Meta-LSTM Ravi and Larochelle (2017) meta-learns the update rule with an RNN meta-learner. MAML Finn, Abbeel, and Levine (2017) trains a global initialization close to the optimal value of every task. Leveraging diverse meta-knowledge further accelerates the learning process. In Meta-SGD Li et al. (2017), the meta-knowledge consists of both initialization and learning rate. ALFA Baik et al. (2020) proposes to meta-learn both the initialization and a hyperparameter update module. Most methods assign the same global knowledge to every task, which leads to sub-optimal solutions in heterogeneous settings. Besides, they are all deterministic and can only learn one solution for a new task.

Bayesian approaches are a long-standing discipline that incorporates uncertainty in modeling. Multiple recent works extend MAML into the Bayesian framework and recast meta-learning as a probabilistic framework Finn, Xu, and Levine (2018); Grant et al. (2018); Yoon et al. (2018); Ravi and Beatson (2019); Garnelo et al. (2018b). PLATIPUS Finn, Xu, and Levine (2018) builds upon amortized variational inference and injects Gaussian noise into the gradient during meta-testing to learn a distribution over model parameters. LLAMA Grant et al. (2018) applies a Laplace approximation to model the parameter distribution, but it requires approximating a high-dimensional covariance matrix. These methods view model parameters (i.e., network weights and biases) as random variables and perform inference on them. This leads to significant challenges when working with complicated models and high-dimensional data.

Our work also loosely connects to “prototype meta-learning” Triantafillou et al. (2019); Snell, Swersky, and Zemel (2017). These studies learn a prototype for every class to be predicted, and the final prediction depends on the distances between instances and prototypes. Amortized Bayesian prototype meta-learning Sun et al. (2021) assumes a distribution over class prototypes. This design requires prior knowledge about the classes of tasks and only applies to the homogeneous classification setup.

Another line of Bayesian meta-learning studies Garnelo et al. (2018b); Wang and Van Hoof (2020); Louizos et al. (2019); Kim et al. (2018) belongs to the neural approximators of the stochastic process family. They learn a prior for every task or further use a hierarchical model that learns the instance prior. However, these methods don’t share knowledge across tasks. Table 6 compares related lines of works with ours.

Problems Tasks Heterogeneity Ambiguity
Regression 2D regression
Weather prediction
Image completion
Classification PlainMulti classification 5way 5shot
CelebA binary classification 2way 5shot
(see Section (A.4))
Table 1: A summary of datasets, tasks and their properties.
Model MAML MetaSGD BMAML MMAML ARML ST-MAML
MSE
Table 2: Regression accuracy on 2D regression tasks.
Model MAML MetaSGD ST-MAML ST-MAML (w/o aug) ST-MAML (w/o tailor)
MSE
Table 3: 10-Shot temperature prediction.

Figure 4: A visualization of trained ST-MAML on the NOAA-GSOD temperature prediction task. The model is given training points (red) and predicts the remaining days of the year (orange). The true temperatures are shown in blue.

3 Experiments

We apply ST-MAML to both few-shot regression and classification to demonstrate its effectiveness on both heterogeneous and ambiguous tasks. In regression, we evaluate ST-MAML in a variety of domains, including 2D curve fitting, whose tasks show both heterogeneity and ambiguity, and two real-world tasks: image completion and weather prediction. We also study two few-shot classification problems. Tasks from the Plain-Multi dataset are heterogeneous, while CelebA classification uses ambiguous decision rules. A summary of the experiment design, datasets, and their properties is shown in Table 1. The number of symbols in the Heterogeneity and Ambiguity columns represents the significance level of the challenge.

3.1 2D Regression

Setup. For 2D regression, we follow a setting similar to Yao et al. (2020). The meta-distribution consists of 6 function families: sinusoid, straight line, quadratic, cubic, quadratic surface, and ripple functions. To increase ambiguity, we perturb the output by adding Gaussian noise with a standard deviation of 0.3. During meta-training, every task is sampled uniformly at random from one of the families, and the size of the training set is fixed. A detailed description of the setup and model architecture is available in the appendix (see Section (A.5)).
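A heterogeneous task sampler of this kind can be sketched as follows. Only the noise standard deviation of 0.3 and the uniform choice among families come from the text. The coefficient ranges, the input range, and the set sizes are illustrative assumptions, and for brevity the sketch covers only the four single-input families (the quadratic surface and ripple families are omitted).

```python
import math
import random

NOISE_STD = 0.3  # stated in the setup above


def sample_family(rng):
    """Pick one function family uniformly; coefficient ranges are illustrative."""
    a, b, c = (rng.uniform(-1.0, 1.0) for _ in range(3))
    families = {
        "sinusoid": lambda x: a * math.sin(x + b),
        "line": lambda x: a * x + b,
        "quadratic": lambda x: a * x * x + b * x + c,
        "cubic": lambda x: a * x ** 3 + b * x * x + c * x,
    }
    name = rng.choice(sorted(families))
    return name, families[name]


def sample_task(rng, n_train=5, n_test=20):
    """Return (family name, noisy training pairs, noisy test pairs)."""
    name, f = sample_family(rng)

    def draw(n):
        xs = [rng.uniform(-5.0, 5.0) for _ in range(n)]
        return [(x, f(x) + rng.gauss(0.0, NOISE_STD)) for x in xs]

    return name, draw(n_train), draw(n_test)
```

With only a handful of noisy training points, several families can explain the same observations, which is the ambiguity this benchmark is meant to expose.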

Baselines, results, and analysis. We have three types of baselines: (1) meta-learning methods designed for homogeneous tasks: MAML Finn, Abbeel, and Levine (2017) and MetaSGD Li et al. (2017); (2) a Bayesian meta-learning method: Bayesian MAML Yoon et al. (2018), which conducts inference on a large number of model parameters; (3) meta-learning methods designed for heterogeneous tasks, including MMAML Vuorio et al. (2019) and ARML Yao et al. (2020). We evaluate the trained model on newly sampled tasks. The results are summarized in Table 2. We showcase fitted curves in the appendix (see Figure 6). Even though we fix the size of the training set and the noise level for every task during meta-training, both can be changed during meta-testing. To increase the ambiguity of every test task, we vary the number of available annotated data points in the training set and the noise level. More analysis and visualization results can be found in appendix A.5.

As shown in Figure 6, all sampled solutions are close to the ground truth when tasks are less uncertain. On the other hand, the figures in appendix A.5 show that as tasks become more ambiguous, due to fewer annotated training data or larger noise, the sampled solutions tend to span a wider space.

3.2 Temperature Prediction

Setup. Next, we evaluate the model in a challenging regression problem using real-world data. The NOAA Global Surface Summary of the Day (GSOD) dataset contains daily weather data from thousands of stations around the world. Each task is created by sampling data points from (station, year) pairs. The model takes in the current day of the year along with weather features such as wind speed, station elevation, precipitation, fog, air pressure, etc. It then learns to predict the average temperature in Fahrenheit on that day. We remove important information like the weather station number, name, latitude, and longitude. Hiding the station information in this way creates a highly heterogeneous problem where each station generates its own task distribution. The model sees days of labeled temperature data before predicting the temperature on test days. More technical details can be found in appendix A.5.

Results and analysis. After training on unique (station, year) tasks, we evaluate the model on a test set of (station, year) pairs. The results are summarized in Table 3. The MSE of MAML is close to double that of ST-MAML. MetaSGD, designed for homogeneous meta-learning, performs poorly because the globally learned learning rate hurts the model’s generalization to unseen tasks from different distributions. This is consistent with our assumption that incorporating task-specific knowledge into the model helps with the task-heterogeneous challenge.

Settings Algorithms Data: Bird Data: Texture Data: Aircraft Data: Fungi
5-way 5-shot MAML
MetaSGD
BMAML
MMAML
HSML
ST-MAML
ST-MAML (w/o aug)
ST-MAML (w/o tailor)
Table 4: 5-way 5-shot classification accuracy with 95% confidence interval on Plain-Multi dataset.

3.3 Image Completion

Setup. We also apply our method to image completion tasks. Here the meta-distribution is a mixture of three image distributions, and every task contains one image sampled randomly from one of them. In meta-training, a fixed number of pixels is observed for every image, which determines the training set size. We use pixel coordinates as inputs and the pixel value as the target variable. The detailed architecture can be found in appendix A.5.
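Casting one image as a regression task, with coordinates as inputs and pixel values as targets, can be sketched as below. The image representation (a list of rows), the normalization of coordinates to [0, 1], and the function name are assumptions for illustration; the image must have at least two rows and two columns for the normalization to be defined.

```python
import random


def image_to_task(image, n_observed, rng):
    """Split an image (list of rows of pixel values) into observed and held-out pixels."""
    h, w = len(image), len(image[0])
    points = [((r / (h - 1), c / (w - 1)), image[r][c]) for r in range(h) for c in range(w)]
    rng.shuffle(points)  # random subset of pixels acts as the annotated training set
    return points[:n_observed], points[n_observed:]
```

Changing `n_observed` at meta-test time changes the training set size, which a mean-pooled set encoder can absorb without any architectural change.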

Figure 5: Visualization of completed images. First column contains original images, second column shows the observations which contains only annotated pixels(left) and 40 annotated pixels(right). The unobserved pixels have been coloured blue for better clarity. The remaining columns correspond to different samples given the context points.

Baselines, results and analysis. Image completion with limited given pixels is a benchmark task for neural processes Garnelo et al. (2018a, b). Thus, we compare ST-MAML with neural processes (NP) Garnelo et al. (2018b) and conditional neural processes (CNP) Garnelo et al. (2018a), which can be viewed as deterministic neural processes. Similar to CNP, we also recast our model into a deterministic framework, where the task representation is modeled as a fixed-dimension vector learned from the training set only. The numerical comparison is shown in Table 5. ST-MAML achieves higher completion precision than NP and CNP. We omit the variance for all methods because the differences are insignificant.

There is a large amount of ambiguity surrounding the completed images. Given limited observed pixels, many plausible images are consistent with the observations, especially for gray images. Uncertainty arises on three levels: inter-class level, inter-distribution level, and cross-distribution level. ST-MAML increases the chance of capturing more potential truths by learning a distribution of possibilities rather than a unique mapping. We visualize observations and their completions in Figure 5. Our set operations allow us to learn from a training set of any size during meta-testing. Thus, as more pixels are observed during meta-testing, the task becomes less ambiguous, and the completed images from different samples stay close to the original image.

Model NP CNP ST-MAML (deter) ST-MAML
BCE
Table 5: Image completion accuracy.

3.4 Heterogeneous Classification

Setup and baselines. N-way K-shot classification is a popular setup in few-shot meta-learning Chen et al. (2019); Ren et al. (2018); Vinyals et al. (2016). The training set of every task consists of N classes with K labeled samples in each class. We apply ST-MAML to the benchmark heterogeneous meta-learning dataset Plain-Multi, proposed in Yao et al. (2019). The meta-distributions consist of four datasets, and every task is sampled uniformly at random from one of them. Following the benchmark architecture, the feature learner contains four convolutional blocks. The input is fed into two convolutional blocks with 6 channels; the output is then appended with the target variable and passed into a two-layer MLP module to model the mean and variance of the stochastic task variable. We compare to MAML Finn, Abbeel, and Levine (2017), MetaSGD Li et al. (2017), MMAML Vuorio et al. (2019), HSML Yao et al. (2019), and the probabilistic method BMAML Yoon et al. (2018).
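Episode construction for heterogeneous N-way K-shot classification can be sketched as follows: pick one dataset uniformly at random, then draw N classes and K support samples per class from it. The data layout (a dictionary from dataset name to a dictionary from class name to a list of items), the query-set size, and the function names are assumptions for illustration.

```python
import random


def sample_episode(data_by_class, n_way, k_shot, n_query, rng):
    """Draw an N-way K-shot task with episode-local labels 0..N-1."""
    classes = rng.sample(sorted(data_by_class), n_way)
    support, query = [], []
    for label, cls in enumerate(classes):
        items = rng.sample(data_by_class[cls], k_shot + n_query)
        support += [(x, label) for x in items[:k_shot]]
        query += [(x, label) for x in items[k_shot:]]
    return support, query


def sample_heterogeneous_episode(datasets, n_way, k_shot, n_query, rng):
    """Pick one dataset uniformly at random, then draw an episode from it."""
    name = rng.choice(sorted(datasets))
    support, query = sample_episode(datasets[name], n_way, k_shot, n_query, rng)
    return name, support, query
```

The dataset name is returned only for inspection; in the heterogeneous setup the learner does not see it and has to infer the task identity from the support set.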

Results and analysis. After meta-training, the model is evaluated on test tasks for each dataset, and the results are summarized in Table 4. The most relevant method is MMAML. It learns a deterministic task embedding with an RNN module and modulates all parameters in both the base learner and the task learner. Our method outperforms it on every dataset. Also, the probabilistic framework enables us to achieve consistently low variance. HSML requires prior knowledge of the number of clusters, which plays an important role in the final accuracy.

3.5 Ablation Studies

When facing a task, the initial state of the knowledge set includes both the tailored initialization and the augmented feature. To better investigate the contribution of each component, we perform ablation experiments on both temperature prediction and Plain-Multi classification. The results are shown in Table 3 and Table 4. Both types of task-specific knowledge improve performance over the baselines, and together they give the best performance.

4 Conclusion

Task heterogeneity and task ambiguity are two critical challenges in meta-learning. Most meta-learning methods assign the same initialization to every task and fail to handle task heterogeneity. They also disregard the task ambiguity issue and learn one solution for every task. ST-MAML encodes tasks using an NN-based stochastic task module plus a set-based operation for permutation invariance. This stochastic task design allows global knowledge to be customized with the learned stochastic task distribution. We further convert latent task encodings to augmented features to improve the interaction between model parameters and input variables. The probabilistic framework allows us to learn a distribution of solutions for ambiguous tasks and recover more potential task identities. Empirically, we design extensive experiments on regression and classification problems and show that ST-MAML provides an efficient way to learn from diverse and ambiguous tasks. We leave the challenge of handling domain generalization during meta-testing to future work.

References

Appendix A Appendix

A.1 Model Comparison

Category Tasks Knowledge Set Tailoring Sampling Inference on
HoMAMLs MAML Finn, Abbeel, and Levine (2017) Initialization
MetaSGD Li et al. (2017) Initialization, lr
HeMAMLs MMAML Vuorio et al. (2019) Initialization
HSML Yao et al. (2019) Initialization
NPs NP Garnelo et al. (2018b) Aug feature Representation
CNP Garnelo et al. (2018a) Aug feature
PMAMLs BMAML Yoon et al. (2018) Initialization Parameters
PLATIPUS Finn, Xu, and Levine (2018) Initialization Parameters
ST-MAML Initialization, Aug feature Representation

Table 6: Model comparison table. HoMAMLs are MAMLs designed for task homogeneity, and HeMAMLs are for heterogeneity. NPs describe methods in Neural Processes family. PMAMLs mean probabilistic extensions of MAML. Aug feature represents the augmented features.

A.2 Approximation for the posterior distribution

Given the training set of a task , the stochastic task variable is supposed to infer its posterior distribution conditioned on only, specifically, we have the true posterior:

(18)

The empirical distribution is only known in the form of sample pairs. Thus, the true posterior distribution is intractable. Based on our design, we suppose the prior distribution is a multivariate Gaussian whose mean and variance are the output of a set operator acting on the training pairs. Because the posterior should stay close to the prior and is derived from it, we approximate it with the output of the same set operator acting on both the training and test pairs.

A.3 Derivation of ELBO approximation as Variational Information Bottleneck Objective

For a task, our fine-tuned task-specific knowledge set contains two variables: the model parameters and the augmented features. Given the task inputs, we seek a task-specific knowledge set that is maximally informative of the test targets while being maximally compressive of the training targets Titsias, Nikoloutsopoulos, and Galashov (2020); Tishby, Pereira, and Bialek (2000). Correspondingly, we would like to maximize the conditional mutual information between the knowledge set and the test targets and minimize the conditional mutual information between the knowledge set and the training targets. The information bottleneck objective is:

(19)

We show the following lemma in appendix A.3:

Lemma 1

Given a task , maximizing the information bottleneck loss defined in (19) is equivalent to maximizing the weighted ELBO :

(20)
Proof 1

To lower bound IB objective defined in Eq. (19), we derive the lower bound for first term and upper bound for second term . Further, we assume a distribution as a variational approximation of the true distribution .

(21)
(22)
(23)

The last part follows from the fact that is independent of given . Putting this together:

(24)

However, the above conditional distribution is intractable due to the unknown data distribution . To derive the upper bound, we introduce a variational approximation for .

Take it into the Eq. (21), we have:

(25)

In the above equation, we use in the second step.

The second term is irrelevant to our objective so we can treat it as a constant. Note that:

(26)

Thus, an unbiased estimation of the first term is:

(27)

We derive the upper bound for second term:

(28)

The denominator is intractable for unknown . We approximate it with . With similar derivation, the second term is upper bounded by:

(29)

Similarly, its unbiased estimation is given as:

(30)

Combining two terms, we get the total unbiased estimation of the IB loss:

(31)

To incorporate target information, we inject the target variable into posterior and into prior, and get the new approximation:

(32)

Since , where are both deterministic and invertible mappings of , we have . Moreover, are conditionally independent given . Similarly, are deterministic function of and . Thus, the second term in Eq. (32) can be replaced with the divergence between the posterior and prior distribution of , i.e. .

We now look into the log-likelihood term in Eq. (31). Since the two transitions are deterministic:

(33)

According to the analysis, the approximation to be optimized is: