MetaMix
MetaMix for ICML 2021
Meta-learning has proven to be a powerful paradigm for transferring knowledge from previous tasks to facilitate the learning of a novel task. Current dominant algorithms train a well-generalized model initialization, which is adapted to each task via the support set. The crux lies in optimizing the generalization capability of the initialization, which is measured by the performance of the adapted model on the query set of each task. Unfortunately, empirical results show that this generalization measure pushes the initialization to overfit the query sets but fail the support sets, which significantly impairs generalization and adaptation to novel tasks. To address this issue, we include the support set when evaluating the generalization, producing a new meta-training strategy, MetaMix, that linearly combines the input and hidden representations of samples from both the support and query sets. Theoretical studies on classification and regression tasks show how MetaMix can improve the generalization of meta-learning. More remarkably, MetaMix obtains state-of-the-art results by a large margin across many datasets and remains compatible with existing meta-learning algorithms.
Meta-learning, or learning to learn [38], empowers agents with a core aspect of intelligence: quickly learning a new task from as few as a handful of examples by drawing upon the knowledge learned from prior tasks. The recent resurgence of meta-learning has produced more effective algorithms that have been deployed in areas such as computer vision [16, 22, 37], natural language processing [6, 13, 25], and robotics [44, 49]. Some of the dominant algorithms learn a transferable metric space from previous tasks [30, 35, 41], but they are only applicable to classification problems. In contrast, gradient-based algorithms [7, 9, 21], which frame meta-learning as a bi-level optimization problem, are flexible and general enough to be independent of problem types; we focus on them in this work.
The bi-level optimization procedure of gradient-based algorithms is illustrated in Figure 3. In the inner-loop, the initialization of a model globally shared across tasks (i.e., $\theta_0$) is adapted to each task (e.g., $\phi_1$ for the first task) via gradient descent over the support set of the task. To reach the desired goal that optimizing from this initialization leads to fast adaptation and generalization, a meta-training objective evaluating the generalization capability of the initialization on all meta-training tasks is optimized in the outer-loop. Specifically, the generalization capability on each task is measured by the performance of the adapted model on a set disjoint from the support set, namely the query set.
The learned initialization, however, is at high risk of overfitting the meta-training tasks and failing on meta-testing tasks. As evidenced in Figure 3 for drug activity prediction, which we detail later in the experiments, the prediction accuracy of the learned initialization without adaptation is exceptionally high on the query sets of meta-training tasks but surprisingly low on those of meta-testing tasks. Improving this generalization from meta-training to meta-testing tasks, which we call meta-generalization, is especially challenging: standard regularizers like weight decay lose their power because they hurt the flexibility of fast adaptation in the inner-loop. The few existing solutions attempt to regularize the search space of the initialization [47] or enforce a fair performance of the initialization across all meta-training tasks [15], while preserving the expressive power needed for adaptation.
Rather than passively imposing regularization on the initialization, we turn towards an active approach that solicits more data to meta-train the initialization. The intuition comes from the empirical observations in Figure 3. The huge gap between the prediction accuracy of the learned initialization on query sets and that on support sets suggests that the initialization overfits the query sets but overlooks the support sets. Such an initialization runs counter to the desired one, which should generalize well enough to behave consistently across all sets of the meta-training tasks before it can be expected to generalize to the support sets of meta-testing tasks. To resolve this inconsistency and thereby promote meta-generalization, an intuitive solution is to evaluate and optimize the generalization capability of the initialization in the outer-loop with more data than the query sets alone.
The most immediate source of more data is the support sets, but they alone are far from enough. The support sets contribute little to the value and gradients of the meta-training objective, because the objective is the performance of the adapted model, which has been optimized precisely on the support sets. Figure 3 sheds light on this fact: the gradient norms of the meta-training objective with respect to the initialization contributed by support sets are much smaller than those contributed by query sets.
Thus, in this paper, we are motivated to produce “more” data out of the accessible support and query sets. The resulting strategy we propose, MetaMix, linearly combines either original features or hidden representations of the support and query sets, and performs the same linear interpolation between their corresponding labels. These additional signals for the meta-training objective encourage the learned initialization to generalize consistently to a spectrum of data ranging from the support to the query sets, thereby improving meta-generalization as expected.
Not only do we identify this novel direction to improve meta-generalization, we also offer theoretical insights into why MetaMix works in both classification and regression problems. Moreover, through comprehensive experiments, we demonstrate three significant benefits of MetaMix. First, performance is substantially improved over state-of-the-art meta-learning algorithms and the two regularizers [15, 47] on various real-world datasets. Second, better generalization is achieved even for heterogeneous tasks. Third, MetaMix is compatible with existing and advanced meta-learning algorithms and can readily boost their performance.
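Gradient-based meta-learning algorithms assume a set of tasks to be sampled from a distribution $p(\mathcal{T})$. Each task $\mathcal{T}_i$ consists of a support sample set $\mathcal{D}_i^s = \{(\mathbf{x}_{i,j}^s, \mathbf{y}_{i,j}^s)\}_{j=1}^{K^s}$ and a query sample set $\mathcal{D}_i^q = \{(\mathbf{x}_{i,j}^q, \mathbf{y}_{i,j}^q)\}_{j=1}^{K^q}$, where $K^s$ and $K^q$ denote the number of support and query samples, respectively. The objective of meta-learning is to master new tasks quickly by adapting a well-generalized model learned over the task distribution $p(\mathcal{T})$. Specifically, the model $f$ parameterized by $\theta$ is trained on a large number of tasks sampled from $p(\mathcal{T})$ during meta-training. During meta-testing, $\theta$ is adapted to a new task $\mathcal{T}_t$ with the help of its support set $\mathcal{D}_t^s$ and evaluated on its query set $\mathcal{D}_t^q$.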
Take model-agnostic meta-learning (MAML) [7] as an example. The well-generalized model is grounded to an initialization for $f$, i.e., $\theta_0$, which is adapted to the $i$-th task in a few gradient steps using its support set $\mathcal{D}_i^s$. The generalization performance of the adapted model, i.e., $f_{\phi_i}$, is measured on the query set $\mathcal{D}_i^q$, and in turn used to optimize the initialization $\theta_0$ during meta-training. Let $\mathcal{L}$ and $\mu$ denote the loss function and the inner-loop learning rate, respectively.
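The above interleaved process is formulated as a bi-level optimization problem,
$$\theta_0^* := \arg\min_{\theta_0} \mathbb{E}_{\mathcal{T}_i \sim p(\mathcal{T})}\big[\mathcal{L}(f_{\phi_i}; \mathbf{X}_i^q, \mathbf{Y}_i^q)\big], \quad \text{s.t. } \phi_i = \theta_0 - \mu \nabla_{\theta_0} \mathcal{L}(f_{\theta_0}; \mathbf{X}_i^s, \mathbf{Y}_i^s), \tag{1}$$
where $\mathbf{X}_i^s$ ($\mathbf{X}_i^q$) and $\mathbf{Y}_i^s$ ($\mathbf{Y}_i^q$) represent the concatenation of samples and their corresponding labels for the support (query) set, respectively. In the meta-testing phase, to solve the new task $\mathcal{T}_t$, the optimal initialization $\theta_0^*$ is fine-tuned on its support set $\mathcal{D}_t^s$ to obtain the task-specific parameters $\phi_t$.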
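To make the bi-level structure of Eqn. (1) concrete, here is a minimal sketch of one-step MAML on toy 1-D linear regression tasks $y = wx$ with a squared loss. Everything in it (the scalar model `theta * x`, the task slopes, `mu`, `outer_lr`, and all helper names) is a hypothetical illustration rather than the paper's setup. For this quadratic loss the inner-step Jacobian $\partial\phi_i/\partial\theta_0 = 1 - \mu H_s$ is available in closed form, so the outer gradient below is exact.

```python
from typing import List, Sequence, Tuple

# A task is (support_x, support_y, query_x, query_y); all values are toy scalars.
Task = Tuple[List[float], List[float], List[float], List[float]]


def mse_grad(theta: float, xs: Sequence[float], ys: Sequence[float]) -> float:
    """Gradient of (1/K) * sum((theta * x - y)^2) with respect to theta."""
    return 2.0 / len(xs) * sum(x * (theta * x - y) for x, y in zip(xs, ys))


def mse_hess(xs: Sequence[float]) -> float:
    """Second derivative of the same loss; constant for a linear model."""
    return 2.0 / len(xs) * sum(x * x for x in xs)


def maml_meta_gradient(theta0: float, task: Task, mu: float) -> float:
    xs_s, ys_s, xs_q, ys_q = task
    # Inner loop: one gradient step on the support set gives phi_i.
    phi = theta0 - mu * mse_grad(theta0, xs_s, ys_s)
    # Outer loop: chain rule from the query loss at phi back to theta0.
    return mse_grad(phi, xs_q, ys_q) * (1.0 - mu * mse_hess(xs_s))


def meta_train(tasks: Sequence[Task], mu: float = 0.1, outer_lr: float = 0.1,
               steps: int = 500, theta0: float = 0.0) -> float:
    # Empirical version of Eqn. (1): average the meta-gradient over tasks.
    for _ in range(steps):
        grad = sum(maml_meta_gradient(theta0, t, mu) for t in tasks) / len(tasks)
        theta0 -= outer_lr * grad
    return theta0


def make_task(w: float) -> Task:
    xs_s, xs_q = [0.5, 1.0], [-1.0, 1.5]
    return xs_s, [w * x for x in xs_s], xs_q, [w * x for x in xs_q]


if __name__ == "__main__":
    tasks = [make_task(w) for w in (1.0, 2.0, 4.0)]
    print(round(meta_train(tasks), 6))  # mean task slope, 7/3
```

Because every toy task here shares the same inputs, the meta-objective is a weighted sum of $(\theta_0 - w_i)^2$ terms with equal weights, so the learned initialization converges to the average task slope.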
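In practical situations, the distribution $p(\mathcal{T})$ is unknown, so the expected performance in Eqn. (14) cannot be estimated directly. Instead, the common practice is to approximate it with the empirical performance over the $I$ meta-training tasks, i.e.,
$$\theta_0^* := \arg\min_{\theta_0} \frac{1}{I}\sum_{i=1}^{I} \mathcal{L}(f_{\phi_i}; \mathbf{X}_i^q, \mathbf{Y}_i^q), \quad \text{s.t. } \phi_i = \theta_0 - \mu \nabla_{\theta_0}\mathcal{L}(f_{\theta_0}; \mathbf{X}_i^s, \mathbf{Y}_i^s). \tag{2}$$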
Unfortunately, this empirical risk observes the generalization ability of the initialization only on a finite set of tasks. When the function $f$ is sufficiently powerful, a trivial solution for $\theta_0$ is to memorize all tasks [47]. Such memorization leads to poor meta-generalization (see Figure 3) of $\theta_0$ to meta-testing tasks. Before proceeding to our solution for improving meta-generalization, we first consider what has been memorized. Since $\theta_0$ is optimized for its generalization performance on all query sets in the outer-loop, it is the query sets in particular that are memorized.
Inspired by data augmentation [5, 50, 51], which is used to mitigate the memorization of training samples in conventional supervised learning, we propose to alleviate task memorization by involving more data to meta-train $\theta_0$. Compared to supervised learning, where the augmentation originates from only the memorized training samples, we additionally have access to the support sets besides the memorized query sets during meta-training. Moreover, Figure 3 shows that $\theta_0$, as a sign of poor meta-generalization, behaves very differently on the query and the support sets. Therefore, we propose MetaMix, which produces more data to meta-train $\theta_0$ by mixing samples from both query sets and support sets. The mixing strategy follows Manifold Mixup [40], where not only inputs but also hidden representations are mixed up. Assume that the model $f$ consists of $L$ layers. The hidden representation of a sample set $\mathbf{X}$ at the $l$-th layer is denoted as $f_{\theta^l}(\mathbf{X})$ ($0 \le l \le L-1$), where $f_{\theta^0}(\mathbf{X}) = \mathbf{X}$. For a pair of support and query sets with their corresponding labels in the $i$-th task $\mathcal{T}_i$, i.e., $(\mathbf{X}_i^s, \mathbf{Y}_i^s)$ and $(\mathbf{X}_i^q, \mathbf{Y}_i^q)$, we randomly sample a value of $l \in \{0, 1, \ldots, L-1\}$ and compute the mixed batch of data for meta-training as,
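$$\mathbf{X}_{i,l}^{mix} = \boldsymbol{\Lambda}\, f_{\theta^l}(\mathbf{X}_i^s) + (\mathbf{I} - \boldsymbol{\Lambda})\, f_{\theta^l}(\mathbf{X}_i^q), \qquad \mathbf{Y}_{i,l}^{mix} = \boldsymbol{\Lambda}\,\mathbf{Y}_i^s + (\mathbf{I} - \boldsymbol{\Lambda})\,\mathbf{Y}_i^q, \tag{3}$$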
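where $\boldsymbol{\Lambda} = \mathrm{diag}(\{\lambda_j\}_{j=1}^{K^q})$ and each coefficient $\lambda_j \sim \mathrm{Beta}(\alpha, \beta)$. Here, we assume that the size of the support set and that of the query set are equal, i.e., $K^s = K^q$. If $K^s \neq K^q$, for each data sample in the query set, we randomly select one sample from the support set for mixup. In Appendix B.1, we illustrate the Beta distribution in both symmetric (i.e., $\alpha = \beta$) and skewed (i.e., $\alpha \neq \beta$) shapes. Using the mixed batch produced by MetaMix, we reformulate the outer-loop optimization problem as,
$$\theta_0^* := \arg\min_{\theta_0} \frac{1}{I}\sum_{i=1}^{I} \mathcal{L}\big(f_{\phi_i^{L-l}}; \mathbf{X}_{i,l}^{mix}, \mathbf{Y}_{i,l}^{mix}\big), \quad \text{s.t. } \phi_i = \theta_0 - \mu \nabla_{\theta_0}\mathcal{L}(f_{\theta_0}; \mathbf{X}_i^s, \mathbf{Y}_i^s), \tag{4}$$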
where $f_{\phi_i^{L-l}}$ represents the remaining layers after the mixed layer $l$. MetaMix is flexible enough to be compatible with off-the-shelf gradient-based meta-learning algorithms: it simply replaces the query sets with the mixed batch for meta-training. Taking MAML with MetaMix as an example, we show meta-training and meta-testing in Alg. 1 and Alg. 2 of the Appendix.
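The mixing step of Eqn. (3) can be sketched in plain Python as below. This is a minimal sketch under stated assumptions, not the authors' implementation: samples and labels are lists of floats (labels as one-hot or regression vectors), `layers` is a hypothetical list of $L$ per-sample layer functions, and the defaults `alpha=beta=2.0` are arbitrary illustrative values rather than the paper's settings.

```python
import random
from typing import Callable, List, Optional, Sequence, Tuple

Vector = List[float]
Layer = Callable[[Vector], Vector]


def hidden(x: Vector, layers: Sequence[Layer], l: int) -> Vector:
    """Representation after the first l layers; l = 0 returns the raw input."""
    h = list(x)
    for layer in layers[:l]:
        h = layer(h)
    return h


def interpolate(a: Vector, b: Vector, lam: float) -> Vector:
    return [lam * u + (1.0 - lam) * v for u, v in zip(a, b)]


def metamix_batch(
    xs_s: Sequence[Vector], ys_s: Sequence[Vector],
    xs_q: Sequence[Vector], ys_q: Sequence[Vector],
    layers: Sequence[Layer],
    alpha: float = 2.0, beta: float = 2.0,
    rng: Optional[random.Random] = None,
) -> Tuple[List[Vector], List[Vector], int]:
    """Mix support and query samples at a random layer l (sketch of Eqn. (3))."""
    rng = rng if rng is not None else random.Random()
    l = rng.randrange(len(layers))  # l in {0, ..., L - 1}
    mixed_h: List[Vector] = []
    mixed_y: List[Vector] = []
    for j, (x_q, y_q) in enumerate(zip(xs_q, ys_q)):
        # Pair query j with support j when K^s == K^q, else with a random support sample.
        k = j if len(xs_s) == len(xs_q) else rng.randrange(len(xs_s))
        lam = rng.betavariate(alpha, beta)  # lambda_j ~ Beta(alpha, beta)
        mixed_h.append(interpolate(hidden(xs_s[k], layers, l), hidden(x_q, layers, l), lam))
        mixed_y.append(interpolate(ys_s[k], y_q, lam))  # same lambda_j for the labels
    return mixed_h, mixed_y, l
```

A caller would then pass `mixed_h` through the remaining layers `layers[l:]` and compute the outer-loop loss against `mixed_y`, as in Eqn. (4). Drawing one coefficient per query sample corresponds to the diagonal matrix $\boldsymbol{\Lambda}$.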
We also probe into the mechanism by which MetaMix improves meta-generalization, by linking it to the information-theoretic view of generalization [34]. In [34], SGD optimization is suggested to have two phases, i.e., empirical error minimization and representation compression. The compression phase, which takes much longer, is responsible for generalization. Unfortunately, a few adaptation steps in the inner-loop suffice only to minimize the empirical error on the support set, leaving the hidden representations uncompressed. Thus, the adapted model tends to fail on the query sets, so the burden of minimizing the generalization error in the outer-loop is placed on $\theta_0$, which gradually overfits the query sets. MetaMix, by incorporating the support sets into the outer-loop optimization, which runs for sufficiently many iterations, pulls $\theta_0$ back to behave consistently between the support and the query sets and further compresses the hidden representations of the support set. Specifically, we computed the largest singular value of the hidden representations of the support set after meta-training. As Table 1 shows, MAML with MetaMix reduces the top singular values, which signifies more compressed representations and thereby better generalization to the support set [28]. An initialization $\theta_0$ with this consistent generalization across support and query sets is expected to generalize well to meta-testing tasks.
Class | C1 | C2 | C3 | C4 | C5 |
MAML | |||||
MR-MAML | |||||
MetaMix |
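The largest singular value used in Table 1 can be estimated without a linear-algebra library via power iteration on $\mathbf{H}^\top\mathbf{H}$, where the rows of $\mathbf{H}$ are hidden representations of support samples. The sketch below is a generic illustration (the function names, iteration count, and seed are arbitrary choices), not the authors' measurement code.

```python
import math
import random
from typing import List, Sequence


def matvec(rows: Sequence[Sequence[float]], v: Sequence[float]) -> List[float]:
    """Multiply the matrix given by `rows` with the vector `v`."""
    return [sum(r_k * v_k for r_k, v_k in zip(row, v)) for row in rows]


def top_singular_value(rows: Sequence[Sequence[float]], iters: int = 200, seed: int = 0) -> float:
    """Estimate sigma_max of the matrix whose rows are hidden representations."""
    n_cols = len(rows[0])
    cols = [[row[k] for row in rows] for k in range(n_cols)]  # rows of H^T
    rng = random.Random(seed)
    v = [rng.gauss(0.0, 1.0) for _ in range(n_cols)]
    for _ in range(iters):
        w = matvec(cols, matvec(rows, v))  # w = H^T (H v)
        norm = math.sqrt(sum(x * x for x in w))
        if norm == 0.0:
            return 0.0  # H v vanished, e.g. for an all-zero matrix
        v = [x / norm for x in w]
    hv = matvec(rows, v)
    return math.sqrt(sum(x * x for x in hv))  # ||H v|| with unit v approximates sigma_max


if __name__ == "__main__":
    print(top_singular_value([[3.0, 0.0], [0.0, 1.0]]))  # approximately 3.0
```

Power iteration converges at a rate governed by the ratio of the top two squared singular values, so a fixed iteration budget is adequate for this illustration; a production measurement would more likely call an SVD routine.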
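In this section, we theoretically study the effectiveness of MetaMix in two special scenarios of classification and regression tasks. To show a general case, we omit the task-level index $i$ for all associated symbols and write $\mathcal{L}(\theta;\mathcal{D})$ for $\mathcal{L}(f_\theta;\mathcal{D})$. The gradients and Hessians of the loss function computed on the support set $\mathcal{D}^s$ and the query set $\mathcal{D}^q$ are denoted by $\nabla_\theta\mathcal{L}(\theta;\mathcal{D}^s)$, $\nabla_\theta\mathcal{L}(\theta;\mathcal{D}^q)$, $\nabla^2_\theta\mathcal{L}(\theta;\mathcal{D}^s)$, and $\nabla^2_\theta\mathcal{L}(\theta;\mathcal{D}^q)$, respectively. For brevity, we denote the gradients $\nabla_{\theta_0}\mathcal{L}(\theta_0;\mathcal{D}^s)$, $\nabla_{\theta_0}\mathcal{L}(\theta_0;\mathcal{D}^q)$, and $\nabla_{\phi}\mathcal{L}(\phi;\mathcal{D}^q)$ by $\mathbf{g}_s$, $\mathbf{g}_q$, and $\tilde{\mathbf{g}}_q$. A similar simplification is applied to the Hessians computed on $\mathcal{D}^s$ and $\mathcal{D}^q$ at $\theta_0$, denoted by $\mathbf{H}_s$ and $\mathbf{H}_q$.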
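We first focus on the objective in Eqn. (14) without MetaMix. As the task-specific model is updated by one-step gradient descent, $\phi = \theta_0 - \mu\mathbf{g}_s$, the gradient of Eqn. (14) w.r.t. $\theta_0$ in the outer-loop optimization is:
$$\nabla_{\theta_0}\mathcal{L}(\phi;\mathcal{D}^q) = \big(\mathbf{I} - \mu\mathbf{H}_s\big)\,\tilde{\mathbf{g}}_q. \tag{5}$$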
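Following the analysis in [29], we present the following lemma, which approximates $\tilde{\mathbf{g}}_q$ with first- and second-order Taylor expansions of $\mathcal{L}(\cdot;\mathcal{D}^q)$ around $\theta_0$.
The gradient of Eqn. (14) can be approximated by
$$\nabla_{\theta_0}\mathcal{L}(\phi;\mathcal{D}^q) = \mathbf{g}_q - \mu\big(\mathbf{H}_q\mathbf{g}_s + \mathbf{H}_s\mathbf{g}_q\big) + O(\mu^2). \tag{6}$$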
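Ignoring the term $O(\mu^2)$, the approximated gradient in Eqn. (6), $\mathbf{g}_q - \mu(\mathbf{H}_q\mathbf{g}_s + \mathbf{H}_s\mathbf{g}_q)$, can be considered as the gradient w.r.t. $\theta_0$ of a loss function defined as
$$\mathcal{L}(\theta_0;\mathcal{D}^q) - \mu\,\mathbf{g}_s^\top\mathbf{g}_q, \tag{7}$$
since $\nabla_{\theta_0}(\mathbf{g}_s^\top\mathbf{g}_q) = \mathbf{H}_s\mathbf{g}_q + \mathbf{H}_q\mathbf{g}_s$.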
Notice that $\theta_0$ is updated by a gradient-based optimizer with the gradient $\nabla_{\theta_0}\mathcal{L}(\phi;\mathcal{D}^q)$. When this gradient is approximated by Eqn. (6), the update of $\theta_0$ approximately minimizes the loss function (7). Eqn. (7) has two terms. The first term is the loss on the query set $\mathcal{D}^q$. The second term can be viewed as a regularizer that encourages similarity between the gradients of $\mathcal{L}$ computed on the support and query sets. In this way, the inner-loop optimization driven by the support set can approximate outer-loop gradient descent on the query set, so the task-specific model after fast adaptation is expected to generalize well to the query set. However, the inner-loop learning rate $\mu$ is often small, which limits the effect of the second term and ultimately pushes the model initialization to overfit the query set.
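The approximation in Eqns. (5) to (7) can be checked numerically on a toy problem. The sketch below assumes, purely for illustration, scalar quadratic support and query losses $\tfrac{a}{2}(\theta - s)^2$ and $\tfrac{b}{2}(\theta - q)^2$, so that $\mathbf{g}_s = a(\theta_0 - s)$, $\mathbf{H}_s = a$, $\mathbf{g}_q = b(\theta_0 - q)$, and $\mathbf{H}_q = b$. For these losses the gap between the exact MAML gradient and the truncated Eqn. (6) works out to exactly $\mu^2 a^2 b(\theta_0 - s)$, which is $O(\mu^2)$, and the gradient of the surrogate loss (7) coincides with the truncated approximation.

```python
from typing import Callable


def exact_meta_grad(theta0: float, a: float, s: float, b: float, q: float, mu: float) -> float:
    """Exact d/dtheta0 of the query loss at phi = theta0 - mu * g_s (Eqn. (5))."""
    phi = theta0 - mu * a * (theta0 - s)
    return b * (phi - q) * (1.0 - mu * a)


def approx_meta_grad(theta0: float, a: float, s: float, b: float, q: float, mu: float) -> float:
    """Eqn. (6) without the O(mu^2) term: g_q - mu * (H_q g_s + H_s g_q)."""
    g_s, g_q = a * (theta0 - s), b * (theta0 - q)
    return g_q - mu * (b * g_s + a * g_q)


def surrogate_loss(theta0: float, a: float, s: float, b: float, q: float, mu: float) -> float:
    """Eqn. (7): query loss minus mu times the support/query gradient inner product."""
    g_s, g_q = a * (theta0 - s), b * (theta0 - q)
    return 0.5 * b * (theta0 - q) ** 2 - mu * g_s * g_q


def central_diff(f: Callable[[float], float], x: float, h: float = 1e-4) -> float:
    return (f(x + h) - f(x - h)) / (2.0 * h)


if __name__ == "__main__":
    a, s, b, q, theta0 = 2.0, 0.5, 3.0, -1.0, 1.0
    for mu in (0.1, 0.05, 0.025):
        gap = exact_meta_grad(theta0, a, s, b, q, mu) - approx_meta_grad(theta0, a, s, b, q, mu)
        # gap / mu^2 should stay close to a**2 * b * (theta0 - s) = 6.0
        print(mu, gap / mu ** 2)
```

Halving $\mu$ shrinks the gap by a factor of four, which is the $O(\mu^2)$ behaviour the lemma relies on.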
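In MetaMix, instead, we introduce the support set into the outer-loop optimization by linearly combining it with the query set, i.e., $\mathbf{X}^{mix} = \lambda\mathbf{X}^s + (1-\lambda)\mathbf{X}^q$, $\mathbf{Y}^{mix} = \lambda\mathbf{Y}^s + (1-\lambda)\mathbf{Y}^q$. Denoting $\mathcal{D}^{mix} = (\mathbf{X}^{mix}, \mathbf{Y}^{mix})$ and its gradient as $\mathbf{g}_{mix} = \nabla_{\theta_0}\mathcal{L}(\theta_0;\mathcal{D}^{mix})$, the approximated loss function in Eqn. (7) turns into
$$\mathcal{L}(\theta_0;\mathcal{D}^{mix}) - \mu\,\mathbf{g}_s^\top\mathbf{g}_{mix}. \tag{8}$$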
Since the second term in Eqn. (8) is a regularizer to encourage the gradient computed by support sets to be similar to that of the first term, it is sufficient to focus on analyzing the first term. Here we study the effectiveness on both regression and classification scenarios. For brevity, we remove the subscript and denote the loss and of sample by and , respectively.
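In the classification problem, we consider a double linear loss, which is linear in both $\mathbf{x}$ and $\mathbf{y}$. Assuming $\lambda \sim \mathrm{Beta}(\alpha, \beta)$, we expand the loss function on $\mathcal{D}^{mix}$ as:
$$\mathbb{E}_{\lambda}\big[\mathcal{L}(\mathcal{D}^{mix})\big] = \mathbb{E}[\lambda^2]\,\mathcal{L}(\mathbf{X}^s,\mathbf{Y}^s) + \mathbb{E}[(1-\lambda)^2]\,\mathcal{L}(\mathbf{X}^q,\mathbf{Y}^q) + \mathbb{E}[\lambda(1-\lambda)]\big(\mathcal{L}(\mathbf{X}^s,\mathbf{Y}^q) + \mathcal{L}(\mathbf{X}^q,\mathbf{Y}^s)\big). \tag{9}$$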
The full proof is available in Appendix A.2. In Eqn. (9), the first two terms are the original losses on the support and query sets, so optimizing Eqn. (9) minimizes the original loss on both sets. This objective could also be realized by simply combining the support and query sets in the outer-loop optimization; compared with such a simple combination, the improvement of MetaMix mainly stems from the third term, which can be interpreted as a cross-set distillation loss and is known to improve generalization performance [14, 23].
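The expansion used in Eqn. (9) follows from bilinearity alone, and a few lines can confirm it for one fixed $\lambda$; taking expectations over $\lambda$ then weights each term by $\mathbb{E}[\lambda^2]$, $\mathbb{E}[(1-\lambda)^2]$, or $\mathbb{E}[\lambda(1-\lambda)]$. As an assumption for illustration, the sketch uses the hypothetical double linear loss $\ell(\mathbf{x}, \mathbf{y}) = -\mathbf{y}^\top \mathbf{W}\mathbf{x}$, which is linear in each argument; `W` and the sample vectors are arbitrary toy values.

```python
from typing import List, Sequence, Tuple


def bilinear_loss(x: Sequence[float], y: Sequence[float], W: Sequence[Sequence[float]]) -> float:
    """Hypothetical double linear loss l(x, y) = -y^T W x (linear in x and in y)."""
    return -sum(y_k * sum(w * x_m for w, x_m in zip(row, x)) for y_k, row in zip(y, W))


def mix(a: Sequence[float], b: Sequence[float], lam: float) -> List[float]:
    return [lam * u + (1.0 - lam) * v for u, v in zip(a, b)]


def expansion_terms(x_s, y_s, x_q, y_q, W, lam: float) -> Tuple[float, float]:
    """Return (loss on the mixed pair, right-hand side of the expansion)."""
    lhs = bilinear_loss(mix(x_s, x_q, lam), mix(y_s, y_q, lam), W)
    rhs = (lam ** 2 * bilinear_loss(x_s, y_s, W)                # support loss
           + (1.0 - lam) ** 2 * bilinear_loss(x_q, y_q, W)      # query loss
           + lam * (1.0 - lam) * (bilinear_loss(x_s, y_q, W)    # cross-set terms
                                  + bilinear_loss(x_q, y_s, W)))
    return lhs, rhs


if __name__ == "__main__":
    W = [[0.5, -1.0, 2.0], [1.5, 0.0, -0.5]]
    lhs, rhs = expansion_terms([1.0, 2.0, -1.0], [1.0, 0.0], [0.0, -1.0, 3.0], [0.0, 1.0], W, 0.3)
    print(abs(lhs - rhs) < 1e-12)  # True: the expansion is an identity for bilinear losses
```

The cross-set terms pair support inputs with query labels and vice versa, which is the part that a plain concatenation of the two sets would not produce.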
In the regression problem, we consider a least-squares loss and assume the support set and the query set are both i.i.d. sampled from the same unknown distribution, whose mapping function is defined as $f^*$. As $\mathbf{X}^s$ and $\mathbf{X}^q$ are the same in expectation, we denote both of them by $\mathbf{X}$. We further assume that the difference between the support and query sets lies in the outputs $\mathbf{Y}^s$ and $\mathbf{Y}^q$, because the mapping functions that generate them are corrupted by noise. Specifically, we assume
(10) |
where the two noise terms are sampled from a zero-mean distribution. To show that Eqn. (10) can be satisfied, we give a simple case in Appendix A.3. Because of the noise, training the meta-model on either the support set or the query set alone cannot recover the genuine model $f^*$. However, if we optimize the MetaMix loss function, we can recover the genuine model as:
(11) |
The derivation of the second equality is available in Appendix A.3. The result in Eqn. (11) indicates that minimizing the loss estimated on the data set with MetaMix is equivalent to minimizing the loss estimated on clean data. Therefore, MetaMix can recover the unbiased model, which improves meta-generalization in the regression problem.
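The denoising effect can be illustrated with a small simulation. This sketch is a hypothetical instance chosen for clarity, not necessarily the construction in Appendix A.3: it assumes a linear genuine mapping $f^*(x) = 2x + 1$, shared inputs for support and query, support labels offset by $+c$ and query labels offset by $-c$, and a symmetric $\mathrm{Beta}(2, 2)$, so that $\mathbb{E}[\lambda] = 1/2$ balances the two offsets. It fits a line by ordinary least squares to the query labels alone and to the MetaMix-style mixed labels.

```python
import random
from typing import Sequence, Tuple

Line = Tuple[float, float]  # (slope, intercept)


def fit_line(xs: Sequence[float], ys: Sequence[float]) -> Line:
    """Ordinary least squares for y = w * x + b; returns (w, b)."""
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    sxx = sum((x - mx) ** 2 for x in xs)
    sxy = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    w = sxy / sxx
    return w, my - w * mx


def toy_fits(n: int = 5000, c: float = 1.0, alpha: float = 2.0, beta: float = 2.0,
             seed: int = 0) -> Tuple[Line, Line]:
    rng = random.Random(seed)

    def f_star(x: float) -> float:
        return 2.0 * x + 1.0  # hypothetical genuine mapping

    xs = [rng.uniform(-1.0, 1.0) for _ in range(n)]
    ys_s = [f_star(x) + c for x in xs]  # assumed support offset +c
    ys_q = [f_star(x) - c for x in xs]  # assumed query offset -c
    lams = [rng.betavariate(alpha, beta) for _ in range(n)]
    # Inputs are shared (X^s = X^q = X), so mixing only changes the labels.
    ys_mix = [lam * a + (1.0 - lam) * b for lam, a, b in zip(lams, ys_s, ys_q)]
    return fit_line(xs, ys_q), fit_line(xs, ys_mix)


if __name__ == "__main__":
    query_fit, mix_fit = toy_fits()
    print(query_fit)  # intercept biased to 1 - c
    print(mix_fit)    # close to (2.0, 1.0)
```

With a skewed Beta the expected offset $(2\,\mathbb{E}[\lambda] - 1)c$ no longer vanishes, so the recovery in this toy setting depends on the symmetric choice.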
One influential line of meta-learning algorithms learns a transferable metric space between samples from previous tasks [27, 30, 35, 41, 48] and classifies samples via lazy learning with the learned distance metric (e.g., Euclidean distance [35], cosine distance [41]). However, their applications are limited to classification problems; they are infeasible for other problems (e.g., regression). In this work, we focus on gradient-based meta-learning algorithms, which learn a well-generalized model initialization from meta-training tasks [8, 7, 9, 10, 12, 20, 21, 31, 33] and are agnostic to problem types. This initialization is adapted to each task via the support set, and in turn the initialization is updated by maximizing the generalization performance on the query set. These approaches are at high risk of overfitting the meta-training tasks and generalizing poorly to meta-testing tasks.
Common techniques increase the generalization capability via regularization such as weight decay [17], dropout [11, 36], and incorporating noise [2, 3, 39]. As mentioned at the end of Section 3, the model adapted by only a few steps on the support set in the inner-loop likely performs poorly on the query set. To improve such generalization for better adaptation, either the number of parameters to adapt is reduced [32, 52] or adaptive noise is added [19]. The contribution of addressing this inner-loop overfitting towards meta-generalization, though positive, is limited. Only recently were two regularizers proposed to specifically improve meta-generalization: MR-MAML [47], which regularizes the search space of the initialization while still allowing it to be sufficiently adapted in the inner-loop, and TAML [15], which enforces the initialization to behave similarly across tasks. Instead of imposing regularizers on the initialization, our work takes a novel direction by actively soliciting more data to meta-train the initialization.
Note that MetaMix is more than a simple application of conventional data augmentation strategies [5, 40, 50], which both [19] and our experiments show to play a very limited role. We are the first to involve more data in the outer-loop and to identify the indispensable role of support sets.
To show the effectiveness of MetaMix, we conduct comprehensive experiments on three meta-learning problems, namely: (1) drug activity prediction, (2) pose prediction, and (3) image classification. We apply MetaMix on four gradient-based meta-learning algorithms, including MAML [7], MetaSGD [21], T-Net [20], and ANIL [32]. For comparison, we consider the following regularizers: Weight Decay as the traditional regularizer; CAVIA [52] and Meta-dropout [19], which regularize the inner-loop; and MR-MAML [47] and TAML [15], both of which handle meta-generalization.
We study a real-world application of drug activity prediction [26] with 4,276 target assays (i.e., tasks), each of which consists of a few drug compounds with tested activities against the target. We randomly selected 100 assays for meta-testing, 76 for meta-validation, and the rest for meta-training. We repeat the random process four times and construct four groups of meta-testing assays for evaluation. Following [26], we evaluate the squared Pearson correlation coefficient ($R^2$) between the predicted and groundtruth values of all query samples for each $i$-th task, and report the mean and median values over all meta-testing assays, as well as the number of assays whose $R^2$ exceeds a reliability threshold, which is deemed an indicator of reliability in pharmacology. We use a base model of two fully connected layers with 500 hidden units. In , we set . More details on the dataset and experimental settings are discussed in Appendix C.1.
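The evaluation protocol above can be sketched as follows: compute the squared Pearson correlation $R^2$ per meta-testing assay, then report the mean, the median, and the count of assays above a reliability threshold. The function names are hypothetical, and `threshold` is a caller-supplied parameter rather than a value taken from the paper.

```python
import statistics
from typing import List, Sequence, Tuple


def squared_pearson(pred: Sequence[float], truth: Sequence[float]) -> float:
    """R^2 as the square of the Pearson correlation between predictions and targets."""
    mp, mt = statistics.fmean(pred), statistics.fmean(truth)
    cov = sum((p - mp) * (t - mt) for p, t in zip(pred, truth))
    var_p = sum((p - mp) ** 2 for p in pred)
    var_t = sum((t - mt) ** 2 for t in truth)
    if var_p == 0.0 or var_t == 0.0:
        return 0.0  # correlation undefined for constant vectors; treat as no signal
    return cov * cov / (var_p * var_t)


def summarize_assays(assays: Sequence[Tuple[Sequence[float], Sequence[float]]],
                     threshold: float) -> Tuple[float, float, int]:
    """Mean and median R^2 over assays, plus how many assays exceed `threshold`."""
    scores: List[float] = [squared_pearson(p, t) for p, t in assays]
    return statistics.fmean(scores), statistics.median(scores), sum(s > threshold for s in scores)


if __name__ == "__main__":
    assays = [([1.0, 2.0, 3.0], [2.0, 4.0, 6.0]),
              ([1.0, 2.0, 3.0], [3.0, 2.0, 1.0]),
              ([1.0, 2.0, 3.0, 4.0], [1.0, 3.0, 2.0, 4.0])]
    print(summarize_assays(assays, threshold=0.9))  # roughly (0.88, 1.0, 2)
```

Because $R^2$ squares the correlation, a perfectly anti-correlated prediction scores as highly as a perfectly correlated one; this is a property of the squared Pearson coefficient itself, not of the sketch.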
In practice, we notice that updating only the final layer in the inner-loop achieves the best performance, which is equivalent to ANIL. Thus, we apply this inner-loop update strategy to all baselines. For stability, we also use ANIL++ [4], which stabilizes ANIL. In Table 2, we compare MetaMix with the baselines on the four drug evaluation groups. We observe that MetaMix consistently improves performance regardless of the backbone meta-learning algorithm (i.e., ANIL, ANIL++, MetaSGD, T-Net) in all scenarios. In addition, ANIL-MetaMix outperforms the other anti-overfitting strategies. The consistently superior performance, even significantly better than pQSAR-max, the state of the art for this dataset, demonstrates that (1) MetaMix is compatible with existing meta-learning algorithms; and (2) MetaMix is capable of improving meta-generalization. Furthermore, similar to Figure 3, we illustrate the predictive performance of the learned initialization after applying MetaMix in Figure 4. MetaMix resolves the inconsistency and mitigates overfitting on the query set, thereby enhancing meta-generalization.
To further investigate where the improvement stems from, we adopt six different mixup strategies for meta-training. The results of Groups 3 and 4 are reported in Table 3 (see Appendix D.1 for results of Groups 1 and 2). We use Mixup($\mathcal{D}_1$, $\mathcal{D}_2$) to denote the mixup of data $\mathcal{D}_1$ and $\mathcal{D}_2$ (e.g., Mixup($\mathcal{D}^s$, $\mathcal{D}^q$) in our case), and $\mathcal{D}^s \oplus \mathcal{D}^q$ represents the concatenation of $\mathcal{D}^s$ and $\mathcal{D}^q$. The fact that MetaMix enjoys better performance than Mixup() suggests that MetaMix is much more than simple data augmentation: it addresses the inconsistency of the learned initialization across the support and query sets and thereby improves meta-generalization. In addition, involving the support set alone is insufficient for meta-generalization due to its relatively small gradient norm, which is further verified by the unsatisfactory performance of .
Following [47], we use the multitask regression dataset created from Pascal 3D data [43], where a grey-scale image is used as input and each image is labeled with its orientation relative to a fixed pose. 50 and 15 objects are randomly selected for meta-training and meta-testing, respectively. The base model consists of a convolutional encoder and a decoder with four convolutional blocks. We set in (see Appendix C.2 for detailed settings).
Model | 10-shot | 15-shot |
Weight Decay | ||
CAVIA | ||
Meta-dropout | ||
MR-MAML | ||
TAML | ||
ANIL | ||
MAML | ||
MetaSGD | ||
T-Net | ||
ANIL-MetaMix | ||
MAML-MetaMix | ||
MetaSGD-MetaMix | ||
T-Net-MetaMix |
Table 4 shows the performance (averaged MSE with 95% confidence interval) of baselines and MetaMix under 10-shot and 15-shot scenarios. The inner-loop regularizers are not as effective as MR-MAML and TAML in improving meta-generalization, and MAML-MetaMix significantly outperforms MR-MAML, suggesting that bringing in more data is more effective than imposing meta-regularizers alone. We also investigate the influence of mixup strategies on pose prediction in Appendix E.2, which again supports the effectiveness of the proposed mixup strategy in recovering the true, unbiased model, as our theoretical analysis suggests. In Appendix E.1, we investigate the influence of different hyperparameter settings (e.g., $\alpha$ and $\beta$ in $\mathrm{Beta}(\alpha, \beta)$) and demonstrate the robustness of MetaMix against different settings.
For image classification problems, standard benchmarks (e.g., Omniglot [18] and MiniImagenet [41]) are constructed as mutually-exclusive tasks by shuffling labels across tasks, which significantly alleviates the meta-overfitting issue [47]. To show the power of MetaMix, following [47], we adopt the non-mutually-exclusive setting for each image classification benchmark: each class keeps the same classification label across all meta-training tasks. Besides, we investigate image classification for heterogeneous tasks. We use the multi-dataset in [45], which consists of four subdatasets, i.e., Bird, Texture, Aircraft, and Fungi. The non-mutually-exclusive setting is also applied to this multi-dataset. Three representative heterogeneous meta-learning algorithms (i.e., MMAML [42], HSML [45], ARML [46]) are taken as baselines and combined with MetaMix. For each task, the classical N-way, K-shot setting is used to evaluate the performance. We use the standard four-block convolutional neural network as the base model. We set for all datasets. Detailed descriptions of experiment settings and hyperparameters are discussed in Appendix C.3.
Model | Omniglot | MiniImagenet | ||
20-way 1-shot | 20-way 5-shot | 5-way 1-shot | 5-way 5-shot | |
Weight Decay | ||||
CAVIA | ||||
MR-MAML | ||||
Meta-dropout | ||||
TAML | ||||
MAML | ||||
MetaSGD | ||||
T-Net | ||||
ANIL | ||||
MAML-MetaMix | ||||
MetaSGD-MetaMix | ||||
T-Net-MetaMix | ||||
ANIL-MetaMix |
Model | 5-way 1-shot | 5-way 5-shot | ||||||
Bird | Texture | Aircraft | Fungi | Bird | Texture | Aircraft | Fungi | |
MMAML | ||||||||
HSML | ||||||||
ARML | ||||||||
MMAML-MetaMix | ||||||||
HSML-MetaMix | ||||||||
ARML-MetaMix |
In Table 5 and Table 6, we report the performance (accuracy with 95% confidence interval) on homogeneous datasets (i.e., Omniglot, MiniImagenet) and heterogeneous datasets, respectively. As with the other problems, applying MetaMix consistently improves existing meta-learning algorithms on all non-mutually-exclusive datasets. For example, MAML-MetaMix significantly boosts MAML and, most importantly, outperforms MR-MAML, substantiating the effectiveness of MetaMix in improving meta-generalization. We also conduct experiments on the standard mutually-exclusive setting of MiniImagenet in Appendix F.2: although label shuffling already mitigates meta-overfitting significantly, applying MetaMix still improves meta-generalization to some extent. By varying the mixup strategies for image classification on MiniImagenet in Table 7 (results on Omniglot are reported in Appendix F.3), we again corroborate our theoretical analysis that knowledge distillation across support and query sets explains why MetaMix works. Besides, under the MiniImagenet 5-shot scenario, we investigate the influence of different hyperparameters in Appendix F.4, including sampling $\lambda$ from the Beta distribution with different values of $\alpha$ and $\beta$, using different fixed values of $\lambda$, and adjusting the layer to mix up (i.e., $l$ in Eqn. (4)). All these studies indicate the robustness of MetaMix against hyperparameter settings.
In Figure 5 and Appendix F.5, we visualize the decision boundaries of MAML and MAML-MetaMix under the MiniImagenet 5-shot setting, following [19]. We randomly select two classes for a meta-testing task and depict a binary decision boundary on the last layer of hidden representations. The figures show that the mixed samples generated by MetaMix bridge the gap between support and query samples, push the representations to be more compact (as we hypothesized at the end of Section 3), and help the decision boundary generalize better.
Current gradient-based meta-learning algorithms are at high risk of overfitting meta-training tasks and generalizing poorly to meta-testing tasks. To address this issue, we propose MetaMix, a novel strategy that actively involves the support set in the outer-loop optimization process. MetaMix linearly combines the inputs and associated hidden representations of the support and query sets. We theoretically demonstrate that MetaMix can improve meta-generalization. State-of-the-art results on three different real-world datasets demonstrate the effectiveness and compatibility of MetaMix.
In this section, we give the detailed proofs for Section 4 of the main paper.
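We apply first- and second-order Taylor expansions of $\mathcal{L}(\cdot;\mathcal{D}^q)$ around $\theta_0$, evaluated at $\phi = \theta_0 - \mu\mathbf{g}_s$, and obtain the following results:
First-order approximation:
$$\tilde{\mathbf{g}}_q = \mathbf{g}_q + O(\mu) \tag{12}$$
Second-order approximation:
$$\tilde{\mathbf{g}}_q = \mathbf{g}_q + \mathbf{H}_q(\phi - \theta_0) + O(\mu^2) = \mathbf{g}_q - \mu\mathbf{H}_q\mathbf{g}_s + O(\mu^2) \tag{13}$$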
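Recall the bi-level optimization process of MAML as:
$$\theta_0^* := \arg\min_{\theta_0} \mathbb{E}_{\mathcal{T}_i \sim p(\mathcal{T})}\big[\mathcal{L}(f_{\phi_i}; \mathcal{D}_i^q)\big], \quad \text{s.t. } \phi_i = \theta_0 - \mu\nabla_{\theta_0}\mathcal{L}(f_{\theta_0}; \mathcal{D}_i^s). \tag{14}$$
By approximating the gradient of the loss in Eqn. (14) w.r.t. $\theta_0$ using Eqns. (12) and (13), and omitting the task index, we have
$$\begin{aligned}
\nabla_{\theta_0}\mathcal{L}(\phi;\mathcal{D}^q) &= (\mathbf{I} - \mu\mathbf{H}_s)\,\tilde{\mathbf{g}}_q = \mathbf{g}_q - \mu\mathbf{H}_q\mathbf{g}_s + O(\mu^2) - \mu\mathbf{H}_s\tilde{\mathbf{g}}_q && \text{(15)}\\
&= \mathbf{g}_q - \mu\mathbf{H}_q\mathbf{g}_s - \mu\mathbf{H}_s\mathbf{g}_q + O(\mu^2) && \text{(16)}\\
&= \nabla_{\theta_0}\big(\mathcal{L}(\theta_0;\mathcal{D}^q) - \mu\,\mathbf{g}_s^\top\mathbf{g}_q\big) + O(\mu^2) && \text{(17)}
\end{aligned}$$
where Eqn. (15) applies the second-order approximation (i.e., Eqn. (13)) to the leading $\tilde{\mathbf{g}}_q$ term; Eqn. (16) uses Eqn. (12) and the fact that $\mu\mathbf{H}_s \cdot O(\mu) = O(\mu^2)$; and the last Eqn. (17) comes from the facts that $\nabla_{\theta_0}\mathcal{L}(\theta_0;\mathcal{D}^q) = \mathbf{g}_q$ and $\nabla_{\theta_0}(\mathbf{g}_s^\top\mathbf{g}_q) = \mathbf{H}_s\mathbf{g}_q + \mathbf{H}_q\mathbf{g}_s$.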
In this part, we provide the full proof of Eqn. (11). Besides, we conduct an empirical study of this problem via a simple one-dimensional regression task. In this toy experiment, we set the groundtruth regression function as , and generate the support set samples by randomly sampling and setting where is some noise. The query set is generated in the same way, except that . The results in Figure 6 demonstrate that without MetaMix, the original one-step MAML simply overfits the query samples, which verifies our claim in Section 1. When one-step MAML is integrated with MetaMix, the fitted curve nearly matches the curve of the groundtruth function, demonstrating the effectiveness of MetaMix in improving generalization.
The full proof of Eqn. (11) is detailed as follows.
Given and , we can derive that
(21) | ||||
(22) | ||||
(23) | ||||
(24) | ||||
(25) |
where Eqn. (21) is presented with the expectation of a Bernoulli variable; Eqn. (22) is rewritten by introducing Bayes' rule and applying the fact that the Beta distribution is the conjugate prior of the Bernoulli distribution; Eqn. (23) is based on the fact that does not exist in the equation and thus does not affect its value; Eqn. (24) is updated with the expectation of a Bernoulli variable; Eqn. (25) comes from the fact .
For the Beta distribution, we illustrate both symmetric (i.e., $\alpha = \beta$) and skewed (i.e., $\alpha \neq \beta$) scenarios in Figure 7.