IJCAI 2019 : Prototype Propagation Networks (PPN) for Weakly-supervised Few-shot Learning on Category Graph
A variety of machine learning applications require rapid learning from a limited amount of labeled data. However, the success of most current models is the result of heavy training on big data. Meta-learning addresses this problem by extracting common knowledge across different tasks that can be quickly adapted to new tasks. However, existing meta-learning methods do not fully explore weakly-supervised information, which is usually free or cheap to collect. In this paper, we show that weakly-labeled data can significantly improve the performance of meta-learning on few-shot classification. We propose prototype propagation network (PPN) trained on few-shot tasks together with data annotated by coarse labels. Given a category graph of the targeted fine-classes and some weakly-labeled coarse-classes, PPN learns an attention mechanism which propagates the prototype of one class to another on the graph, so that the K-nearest neighbor (KNN) classifier defined on the propagated prototypes results in high accuracy across different few-shot tasks. The training tasks are generated by subgraph sampling, and the training objective is obtained by accumulating the level-wise classification loss on the subgraph. The resulting graph of prototypes can be continually re-used and updated for new tasks and classes. We also introduce two practical test/inference settings which differ according to whether the test task can leverage any weakly-supervised information as in training. On two benchmarks, PPN significantly outperforms most recent few-shot learning methods in different settings, even when they are also allowed to train on weakly-labeled data.
Machine learning (ML) has achieved breakthrough success in a great number of application fields during the past 10 years, due to more expressive model structures, the availability of massive training data, and the fast upgrading of computational hardware/infrastructure. Nowadays, with the support of expensive hardware, we can train very deep neural networks containing thousands of layers on millions or even trillions of data samples within an acceptable time. However, as AI becomes democratized for personal or small-business use, and with growing concerns about data privacy, demand is rapidly increasing for instant learning of highly customized models on edge/mobile devices with limited data. This brings new challenges, since the big data and computational power that major ML techniques rely on are no longer available or affordable. In such cases, ML systems that can quickly adapt to new tasks and produce reliable models from only few-shot training data are highly desirable.
This few-shot learning problem can be addressed by a class of approaches called "meta-learning". Instead of independently learning each task from scratch, the goal of meta-learning is to learn the common knowledge shared across different tasks, or "learning to learn". This knowledge is at the learning/algorithm level and is task-independent, and thus can be applied to new unseen tasks. For example, it can be shared initialization weights [Finn et al.2017], an optimization algorithm [Ravi and Larochelle2017], a distance/similarity metric [Vinyals et al.2016], or a generator of prototypes [Snell et al.2017] that compose the support set of the K-nearest neighbor (KNN) predictor. Therefore, new tasks can benefit from the accumulated meta-knowledge extracted from previous tasks. In contrast to single-task learning, the "training data" in meta-learning are tasks: tasks are sampled from a certain distribution over all possible tasks, and the meta-learner tries to maximize the validation accuracy on these sampled tasks. Meta-learning shares ideas with life-long/continual/progressive learning in that the meta-knowledge can be re-used and updated for future tasks. It generalizes multi-task learning [Caruana1997], since it can be applied to any new task drawn from the same distribution.
Although recent studies of meta-learning have shown its effectiveness on few-shot learning tasks, most of them do not leverage weakly-supervised information, which is usually free or cheap to collect and has been proven helpful when training data is insufficient, e.g., in weakly-supervised [Zhou2017] and semi-supervised learning [Belkin et al.2006, Zhu and Ghahramani2002]. In this paper, we show that weakly-supervised information can significantly improve the performance of meta-learning on few-shot classification. In particular, we leverage weakly-labeled data that are annotated by coarse-class labels, e.g., an image of a Tractor with a coarse label of Machine. These data are usually cheap and easy to collect from web tags or human annotators. We additionally assume that a category graph describing the relationship between fine classes and coarse classes is available, where each node represents a class and each directed edge connects a fine class to its parent coarse class. An example of a category graph is given in Figure 2. It is not necessary for the graph to cover all possible relationships, and a graph containing partial relationships is usually easy to obtain, e.g., the ImageNet category tree built on the synsets of WordNet.
We propose a meta-learning model called prototype propagation network (PPN) to explore the above weakly-supervised information for few-shot classification tasks. PPN produces a prototype per class by propagating the prototype of each class to its child classes on the category graph, where an attention mechanism generates the edge weights used for propagation. The learning goal of PPN is to minimize the validation errors of a KNN classifier built on the propagated prototypes for few-shot classification tasks. The prototype propagation on the graph enables the classification error of data belonging to any class to be back-propagated to the prototype of another class, if there exists a path between the two classes. The classification error on weakly-labeled data can thus be back-propagated to improve the prototypes of other classes, which later contribute to the prototype of few-shot classes via prototype propagation. Therefore, the weakly-labeled data and coarse classes can directly contribute to the update of prototypes of few-shot fine classes, offsetting the lack of training data. The resulting graph of prototypes can be repeatedly used, updated and augmented on new tasks/classes as an episodic memory. Interestingly, this interplay between the prototype graph and the quick adaptation to new few-shot tasks (via KNN) is analogous to the complementary learning system that reconciles episodic memory with statistical learning inside the human hippocampus [Schapiro et al.2017], which is believed to be critical for rapid learning.
Similar to other meta-learning methods, PPN learns from the training processes of different few-shot tasks, each defined on a subset of classes sampled from the category graph. To fully explore the weakly-labeled data, we develop a level-wise method to train tasks, generated by subgraph sampling, for both coarse and fine classes on the graph. In addition, we introduce two testing/inference settings that are common in different practical application scenarios: one (PPN+) is allowed to use weakly-labeled data in test tasks and is given the edges connecting test classes to the category graph, while the other (PPN) cannot access any extra information except for the few-shot training data of the new tasks. In experiments, we extracted two benchmark datasets from ILSVRC-12 (ImageNet) [Deng et al.2009], specifically for weakly-supervised few-shot learning. In different test settings, our method consistently and significantly outperforms the three most recently proposed few-shot learning models and their variants, which are also trained on weakly-labeled data. The prototypes learned by PPN are visualized in Figure 1.
A great number of meta-learning approaches have been proposed to address the few-shot learning problem. There are usually two main ideas behind these works: 1) learning a component of a KNN predictor applied to all the tasks, e.g., the support set [Snell et al.2017], the distance metric [Vinyals et al.2016], or both [Mishra et al.2018]; 2) learning a component of an optimization algorithm used to train different tasks, e.g., an initialization point [Finn et al.2017]. Another straightforward approach is to generate more training data for few-shot tasks by a data augmentation technique or generative model [Lake et al.2015, Wong and Yuille2015]. Our method follows the idea of learning a small support set (prototypes) for KNN, and differs in that we leverage the weakly-labeled data by relating prototypes of different classes.
Auxiliary information: unlabeled data [Ren et al.2018] and inter/intra-task relationship [Nichol and Schulman2018, Liu et al.2019, Ravi and Beatson2019] have recently been used to improve the performance of few-shot learning. Co-training on auxiliary tasks [Oreshkin et al.2018] has also been applied to improve the learning of the similarity metric. In contrast, to the best of our knowledge, we are the first to utilize the weakly-labeled data as the auxiliary information to bring significant improvement.
The prototype propagation in our method inherits ideas from random walk, message passing, belief propagation, and label propagation [Wu et al.2019, Zhou et al.2017, Dong and Yang2019]. A similar idea has also been used in more recent graph neural networks (GNNs) such as graph attention networks (GAT) [Veličković et al.2018]. GNNs are mainly designed for tasks on graph-structured data such as node classification [Hamilton et al.2017], graph embedding [Yu et al.2018], graph clustering [Wang et al.2019] and graph generation [Dai et al.2018]. Although our method uses an attention mechanism similar to GAT for propagation, we have a different training scheme (Algorithm 1) that only requires one-step propagation on a specific directed acyclic graph (DAG). GNNs have been applied to meta-learning in [Garcia and Bruna2018], but the associated graph structure is defined on samples (images) instead of classes/prototypes, and is handcrafted with fully connected edges.
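|Notation|Description|
|$\mathcal{X}^{\mathcal{F}}$|Data for targeted fine-classes|
|$\mathcal{Y}^{\mathcal{F}}$|Labels for targeted fine-classes|
|$\mathcal{X}^{\mathcal{W}}$|Data for weakly-labeled classes|
|$\mathcal{Y}^{\mathcal{W}}$|Labels for weakly-labeled classes|
|$\mathcal{G}$|Category graph (a directed acyclic graph)|
|$\mathcal{G}^i_j$|Level-$j$ of $\mathcal{G}^i$ (a subgraph of $\mathcal{G}$)|
|$\mathcal{D}_{\mathcal{G}^i}$|Distribution of data from classes on subgraph $\mathcal{G}^i$|
|$\mathcal{Y}^{\mathcal{G}^i}$|Set of classes on subgraph $\mathcal{G}^i$|
|$\mathcal{Y}^{\mathcal{G}^i_j}$|Set of classes from level-$j$ of subgraph $\mathcal{G}^i$|
|$T$|A few-shot learning task|
|$\mathcal{T}$|Distribution that task $T$ is drawn from|
|$\mathcal{D}_T$|Distribution of data from classes in task $T$|
|$M(\cdot;\Theta)$|A meta-learner producing learner models|
|$\Theta$|Parameters of the meta-learner|
|$P_y$|Final prototype for class $y$|
|$P^0_y$|Initialized prototype for class $y$|
|$P^+_y$|Propagated prototype for class $y$|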
In weakly-supervised few-shot learning, we learn from two types of data: the few-shot data $\mathcal{X}^{\mathcal{F}}$ annotated by the targeted fine-class labels and the weakly-labeled data $\mathcal{X}^{\mathcal{W}}$ annotated by coarse-class labels. Each $x\in\mathcal{X}^{\mathcal{F}}$ is associated with a fine-class label $y\in\mathcal{Y}^{\mathcal{F}}$, while each $x\in\mathcal{X}^{\mathcal{W}}$ is associated with a coarse-class label $y\in\mathcal{Y}^{\mathcal{W}}$. We assume that there exists a category graph describing the relationship between fine classes and coarse classes. This is a directed acyclic graph (DAG) $\mathcal{G}=(\mathcal{Y}^{\mathcal{F}}\cup\mathcal{Y}^{\mathcal{W}},\mathcal{E})$, where each node denotes a class, and each edge (or arc) $z\rightarrow y\in\mathcal{E}$ connects a parent class $z$ to one of its child classes $y$. An example of the category graph is given in Figure 2: the few-shot classes are the leaves of the graph, while the weakly-labeled classes are the parents (directly connected coarse classes) or ancestors (coarse classes not directly connected but linked via paths) of the few-shot classes; for example, the parent class of "Tractor" is "Farm Machine", while its ancestors include "Farm Machine" and "Machine". A child class can belong to multiple parent classes, e.g., the class "Organ" has two parent classes, "Wind Instrument" and "Keyboard Instrument".
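To make the parent/ancestor distinction concrete, the following minimal Python sketch stores a category graph as a child-to-parents map and collects all ancestors of a class with a depth-first walk. The graph fragment is hypothetical: it reuses the class names from the example above, and the root "Instrument" is an assumed extra node.

```python
from typing import Dict, List, Set

# Hypothetical fragment of a category graph: child class -> list of parent classes.
PARENTS: Dict[str, List[str]] = {
    "Tractor": ["Farm Machine"],
    "Farm Machine": ["Machine"],
    "Machine": [],
    "Organ": ["Wind Instrument", "Keyboard Instrument"],  # a child with two parents
    "Wind Instrument": ["Instrument"],  # "Instrument" is an assumed root
    "Keyboard Instrument": ["Instrument"],
    "Instrument": [],
}


def ancestors(cls: str, parents: Dict[str, List[str]]) -> Set[str]:
    """Return all coarse classes reachable from `cls` by following parent edges."""
    found: Set[str] = set()
    stack = list(parents.get(cls, []))
    while stack:
        node = stack.pop()
        if node not in found:
            found.add(node)
            stack.extend(parents.get(node, []))
    return found


print(sorted(ancestors("Tractor", PARENTS)))  # ['Farm Machine', 'Machine']
print(sorted(ancestors("Organ", PARENTS)))
```

Because a child may have several parents, the walk keeps a `found` set so that a shared ancestor reached along two paths is visited only once.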
We follow the setting of few-shot learning, which draws training tasks $T\sim\mathcal{T}$ from a task distribution $\mathcal{T}$ and assumes that the test tasks are also drawn from the same distribution. For few-shot classification, each task $T$ is defined by a subset of classes; e.g., an $N$-way $K$-shot task refers to classification over $N$ classes, where each class has only $K$ training samples. The training classes and test classes must be sampled from two disjoint sets of classes to avoid overlap. In our problem, as shown in Figure 2, the few-shot classes used for training and test (colored light blue and deep blue, respectively) are also non-overlapping, but we allow them to share some parents on the graph. We also allow training tasks to cover any classes on the graph. Since finer-class labels provide more information about the targeted few-shot classes but are more expensive to obtain, we assume that the amount of weakly-labeled data is reduced exponentially as the class becomes finer. The training aims to solve the following risk minimization (or likelihood maximization) of "learning to learn":
$$\max_{\Theta}\ \mathbb{E}_{T\sim\mathcal{T}}\ \mathbb{E}_{(x,y)\sim\mathcal{D}_T}\ \log \Pr\big(y\mid x;\, M(\cdot;\Theta)\big) \tag{1}$$
where each task $T$ is defined by a subset of classes $\mathcal{Y}^T$, $\mathcal{D}_T$ is the distribution of data-label pairs $(x,y)$ with $y\in\mathcal{Y}^T$, and $\Pr(y\mid x; M(\cdot;\Theta))$ is the likelihood of $(x,y)$ produced by model $M(\cdot;\Theta)$ for task $T$, where $\Theta$ is the meta-learner parameter shared by all the tasks drawn from $\mathcal{T}$.
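The expectation over tasks in the risk minimization of Eq. (1) is typically estimated by sampling episodes. The sketch below is a minimal illustration under stated assumptions, not the authors' code: the class split uses a hypothetical ratio, each sampled task is divided into a support part (used to build the learner) and held-out query examples (on which the likelihood would be evaluated), and the names `split_classes`, `sample_episode` and the toy data are illustrative.

```python
import random
from typing import Dict, List, Tuple

Example = Tuple[int, str]  # (data id, class label)


def split_classes(classes: List[str], train_ratio: float, seed: int = 0) -> Tuple[List[str], List[str]]:
    """Split classes into disjoint training and test subsets."""
    shuffled = list(classes)
    random.Random(seed).shuffle(shuffled)
    cut = int(len(shuffled) * train_ratio)
    return shuffled[:cut], shuffled[cut:]


def sample_episode(data: Dict[str, List[int]], classes: List[str], n_way: int,
                   k_shot: int, n_query: int, rng: random.Random) -> Tuple[List[Example], List[Example]]:
    """Draw an N-way K-shot task: K support and n_query held-out examples per class."""
    chosen = rng.sample(classes, n_way)
    support: List[Example] = []
    query: List[Example] = []
    for label in chosen:
        items = rng.sample(data[label], k_shot + n_query)  # sampled without replacement
        support += [(x, label) for x in items[:k_shot]]
        query += [(x, label) for x in items[k_shot:]]
    return support, query


# Toy data: 10 classes with 20 integer "samples" each.
data = {f"class{i}": list(range(100 * i, 100 * i + 20)) for i in range(10)}
train_classes, test_classes = split_classes(list(data), train_ratio=0.8)
support, query = sample_episode(data, train_classes, n_way=5, k_shot=1, n_query=3,
                                rng=random.Random(1))
```

Splitting at the class level (rather than the sample level) is what guarantees that test tasks only contain classes never seen in training.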
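In our method, $\Pr(y\mid x; M(\cdot;\Theta))$ is computed by a soft version of KNN, $\Pr(y\mid x;\mathcal{S}_T)$, where $\mathcal{S}_T=M(T;\Theta)$ is the set of nearest-neighbor candidates (i.e., the support set of KNN) for task $T$, and $M(\cdot;\Theta)$ defines the mechanism generating the support set. Similar to prototypical networks [Snell et al.2017], the support set is composed of prototypes $\mathcal{S}_T=\{P_y\}_{y\in\mathcal{Y}^T}$, each of which is associated with a class $y\in\mathcal{Y}^T$. Given a data point $x$, we first compute its representation $f(x;\Theta^{cnn})$, where $f(\cdot;\Theta^{cnn})$ is a convolutional neural network (CNN) with parameters $\Theta^{cnn}$; then $\Pr(y\mid x;\mathcal{S}_T)$ is computed as:
$$\Pr(y\mid x;\mathcal{S}_T)=\frac{\exp\big(-\|f(x;\Theta^{cnn})-P_y\|_2^2\big)}{\sum_{z\in\mathcal{Y}^T}\exp\big(-\|f(x;\Theta^{cnn})-P_z\|_2^2\big)} \tag{2}$$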
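As a minimal sketch of the soft KNN classifier in Eq. (2), the snippet below turns the distances between an embedding and the class prototypes into a softmax distribution. It assumes the negative squared Euclidean distance used by prototypical networks and uses plain Python lists in place of tensors; the function name `soft_knn_probs` is hypothetical.

```python
import math
from typing import Dict, List


def soft_knn_probs(z: List[float], prototypes: Dict[str, List[float]]) -> Dict[str, float]:
    """Softmax over negative squared Euclidean distances from embedding z to each prototype."""
    logits = {y: -sum((a - b) ** 2 for a, b in zip(z, p)) for y, p in prototypes.items()}
    top = max(logits.values())  # subtract the max logit for numerical stability
    exps = {y: math.exp(v - top) for y, v in logits.items()}
    total = sum(exps.values())
    return {y: e / total for y, e in exps.items()}


probs = soft_knn_probs([0.5, 0.0], {"cat": [0.0, 0.0], "dog": [2.0, 0.0]})
print(probs)
```

The closest prototype gets the largest probability, but every class receives non-zero mass, which is what makes the classifier differentiable with respect to the prototypes.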
In the following, we introduce prototype propagation, which generates the support set for any task. An overview of all procedures in our model is shown in Figure 3.
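In PPN, each training task $T$ is a level-wise classification on a sampled subgraph $\mathcal{G}^i$, i.e., a classification task over the classes $\mathcal{Y}^{\mathcal{G}^i_j}$, where $\mathcal{G}^i_j$ denotes level-$j$ of subgraph $\mathcal{G}^i$ and $\mathcal{Y}^{\mathcal{G}^i}$ is the set of all the classes on $\mathcal{G}^i$.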
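The prototype propagation is defined on each subgraph $\mathcal{G}^i$, which covers the classes $\mathcal{Y}^{\mathcal{G}^i}$. Given the associated training data $\mathcal{X}^y$ for class $y\in\mathcal{Y}^{\mathcal{G}^i}$, the prototype of the class is initialized as the average of the representations $f(x;\Theta^{cnn})$ of the samples $x\in\mathcal{X}^y$, i.e.,
$$P^0_y=\frac{1}{|\mathcal{X}^y|}\sum_{x\in\mathcal{X}^y} f(x;\Theta^{cnn}) \tag{3}$$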
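For each parent class $z$ of class $y$ on $\mathcal{G}^i$, we propagate $P^0_z$ to class $y$ with an edge weight $a(P^0_y,P^0_z)$ measuring the similarity between classes $y$ and $z$, and aggregate the propagation (the messages) from all the parent classes by
$$P^+_y=\sum_{z\in\mathcal{P}^{\mathcal{G}^i}_y} a(P^0_y,P^0_z)\, P^0_z \tag{4}$$
where $\mathcal{P}^{\mathcal{G}^i}_y$ denotes the set of all parent classes of $y$ on the subgraph $\mathcal{G}^i$, and the edge weight $a(\cdot,\cdot)$ is a learnable similarity metric defined by a dot-product attention [Vaswani et al.2017], i.e.,
$$a(p,q)=\frac{\langle g(p),\, h(q)\rangle}{\|g(p)\|\,\|h(q)\|} \tag{5}$$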
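where $g(\cdot)$ and $h(\cdot)$ are learnable transformations applied to prototypes with parameters $\Theta^{att}$, e.g., linear transformations $g(p)=W_g\,p$ and $h(p)=W_h\,p$. The prototype after propagation is a weighted average of $P^0_y$ and $P^+_y$ with weight $\lambda\in[0,1]$, i.e.,
$$P_y=\lambda P^0_y+(1-\lambda)\,P^+_y \tag{6}$$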
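The propagation in Eqs. (3)-(6) can be sketched for a single class as follows. This is an illustrative version under several assumptions: vectors are plain Python lists, $g$ and $h$ are the example linear maps (given here as matrices `W_g` and `W_h`), the attention weight is taken to be the cosine-normalized dot product of $g(p)$ and $h(q)$, and a class without parents keeps $P^0_y$ unchanged. The helper names are hypothetical, not from the paper.

```python
import math
from typing import Dict, List

Vec = List[float]
Mat = List[List[float]]


def mean(vectors: List[Vec]) -> Vec:
    """Eq. (3): initial prototype P0_y as the average of sample embeddings."""
    return [sum(col) / len(vectors) for col in zip(*vectors)]


def matvec(W: Mat, v: Vec) -> Vec:
    return [sum(w * x for w, x in zip(row, v)) for row in W]


def attention(p: Vec, q: Vec, W_g: Mat, W_h: Mat) -> float:
    """Eq. (5), assumed cosine-normalized: <g(p), h(q)> / (|g(p)| |h(q)|)."""
    gp, hq = matvec(W_g, p), matvec(W_h, q)
    dot = sum(a * b for a, b in zip(gp, hq))
    norm = math.sqrt(sum(a * a for a in gp)) * math.sqrt(sum(b * b for b in hq))
    return dot / norm if norm > 0 else 0.0


def propagate(y: str, P0: Dict[str, Vec], parents: Dict[str, List[str]],
              W_g: Mat, W_h: Mat, lam: float) -> Vec:
    """Eqs. (4) and (6): aggregate parent prototypes, then mix with P0_y."""
    if not parents.get(y):
        return list(P0[y])  # assumption: root classes keep their initial prototype
    plus = [0.0] * len(P0[y])
    for z in parents[y]:
        w = attention(P0[y], P0[z], W_g, W_h)
        plus = [a + w * b for a, b in zip(plus, P0[z])]
    return [lam * a + (1.0 - lam) * b for a, b in zip(P0[y], plus)]


I2 = [[1.0, 0.0], [0.0, 1.0]]  # identity maps for g and h in this toy example
P0 = {"Tractor": mean([[1.0, 0.0], [1.0, 0.0]]), "Farm Machine": [1.0, 1.0]}
P_tractor = propagate("Tractor", P0, {"Tractor": ["Farm Machine"]}, I2, I2, lam=0.5)
```

Only one propagation step is taken, so each class receives messages from its direct parents only; in a training loop, gradients of the loss on `P_tractor` would also flow into the parent prototype.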
For each classification task $T$ on subgraph $\mathcal{G}^i$, $P_y$ is used in Eq. (2) to produce the likelihood probability. The likelihood maximization in Eq. (1) aims to optimize the parameters $\Theta^{cnn}$ of $f(\cdot;\Theta^{cnn})$ and $\Theta^{att}$ of the attention mechanism across all the training tasks.
The goal of meta-training is to learn a parameterized propagation mechanism, defined by Eq. (3)-Eq. (6), on few-shot tasks. In each iteration, we randomly sample a subset of few-shot classes, which together with all their parent classes and edges on $\mathcal{G}$ form a subgraph $\mathcal{G}^i$. From each level $\mathcal{G}^i_j$, a training task $T$ is drawn as the classification over the classes $\mathcal{Y}^{\mathcal{G}^i_j}$. The meta-learning in Eq. (1) is approximated by
$$\max_{\Theta}\ \mathbb{E}_{\mathcal{G}^i\sim\mathcal{G}}\ \sum_{j}\ \mathbb{E}_{(x,y)\sim\mathcal{D}_{\mathcal{G}^i_j}}\ \log \Pr\big(y\mid x;\ \{P_z\}_{z\in\mathcal{Y}^{\mathcal{G}^i_j}}\big) \tag{7}$$
where $\mathcal{D}_{\mathcal{G}^i_j}$ is the distribution of data-label pairs $(x,y)$ from the classes $\mathcal{Y}^{\mathcal{G}^i_j}$. Since the prototype propagation is defined on the whole subgraph, it generates a computational graph relating each class to all of its parent classes. Hence, during training, the classification error on each class is back-propagated to the prototypes of its parent classes, which are updated to improve the validation accuracy of finer classes and are later propagated to generate the prototypes of the few-shot classes. Therefore, the weakly-labeled data of a coarse class contribute to the few-shot learning tasks on the leaf-node classes.
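The subgraph sampling and level-wise task construction can be sketched as below. For illustration only, it assumes that the sampled few-shot classes are extended with all their ancestors (not only direct parents) so that every level of the subgraph is present, and that a class's level is the length of its longest path from a root. The function names and the toy graph (including the class "Harvester") are hypothetical.

```python
import random
from collections import defaultdict
from typing import Dict, List, Set


def level_of(cls: str, parents: Dict[str, List[str]], memo: Dict[str, int]) -> int:
    """Level = length of the longest path from a root (a class without parents)."""
    if cls not in memo:
        ps = parents.get(cls, [])
        memo[cls] = 0 if not ps else 1 + max(level_of(p, parents, memo) for p in ps)
    return memo[cls]


def sample_levelwise_tasks(leaves: List[str], parents: Dict[str, List[str]],
                           n_classes: int, rng: random.Random) -> Dict[int, List[str]]:
    """Sample few-shot classes, close them under parent edges, group classes by level."""
    nodes: Set[str] = set()
    stack = rng.sample(leaves, n_classes)
    while stack:
        cls = stack.pop()
        if cls not in nodes:
            nodes.add(cls)
            stack.extend(parents.get(cls, []))
    memo: Dict[str, int] = {}
    levels: Dict[int, List[str]] = defaultdict(list)
    for cls in nodes:
        levels[level_of(cls, parents, memo)].append(cls)
    # Each level j defines one classification task over its classes.
    return {j: sorted(classes) for j, classes in sorted(levels.items())}


PARENTS = {"Tractor": ["Farm Machine"], "Harvester": ["Farm Machine"],
           "Farm Machine": ["Machine"], "Machine": []}
tasks = sample_levelwise_tasks(["Tractor", "Harvester"], PARENTS, n_classes=2,
                               rng=random.Random(0))
```

Summing the per-level losses over the returned tasks gives a Monte Carlo estimate of the inner sum in the level-wise objective.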
The complete level-wise training procedure of PPN is given in Algorithm 1, each of whose iterations comprises two main stages: prototype propagation (lines 9-11), which builds a computational graph over the prototypes of the classes on a sampled subgraph $\mathcal{G}^i$, and level-wise training (lines 12-16), which updates the parameters $\Theta^{cnn}$ and $\Theta^{att}$ on the per-level classification tasks. In the prototype propagation stage, the prototype of each class is merged with the information from the prototypes of its parent classes, and the classification error using the merged prototypes is backpropagated to update the parent prototypes during the level-wise training stage. To improve computational efficiency, we do not update $P^0_y$ for every propagation. Instead, we lazily update $P^0_y$ for all classes every fixed number of epochs, as shown in lines 3-7.
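The lazy update of the initial prototypes can be pictured as a cache that is recomputed only periodically. In this minimal sketch, `PrototypeBuffer`, `refresh_every` and the stand-in `fake_p0` (which replaces running the CNN over a class's data) are hypothetical names, and the refresh interval is an assumed free parameter.

```python
from typing import Callable, Dict, List


class PrototypeBuffer:
    """Caches P0_y per class and recomputes all entries every `refresh_every` epochs."""

    def __init__(self, refresh_every: int, compute_p0: Callable[[str], List[float]]):
        self.refresh_every = refresh_every
        self.compute_p0 = compute_p0
        self.cache: Dict[str, List[float]] = {}

    def maybe_refresh(self, epoch: int, classes: List[str]) -> bool:
        """Recompute prototypes for all classes on refresh epochs (or if the cache is empty)."""
        if not self.cache or epoch % self.refresh_every == 0:
            for cls in classes:
                self.cache[cls] = self.compute_p0(cls)
            return True
        return False

    def get(self, cls: str) -> List[float]:
        return self.cache[cls]  # between refreshes, propagation reads these cached prototypes


calls: List[str] = []


def fake_p0(cls: str) -> List[float]:
    calls.append(cls)  # stand-in for averaging the CNN embeddings of the class
    return [float(len(calls))]


buf = PrototypeBuffer(refresh_every=3, compute_p0=fake_p0)
refreshed = [buf.maybe_refresh(epoch, ["a", "b"]) for epoch in range(7)]
```

The trade-off is staleness for speed: the cached $P^0_y$ may lag behind the current CNN parameters for up to `refresh_every - 1` epochs, in exchange for not re-embedding every class's data at each iteration.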
We study two test settings for weakly-supervised few-shot learning, both of which are used in the evaluation of PPN in our experiments. They differ in whether or not the weakly-supervised information, i.e., the weakly-labeled data and the connections of new classes to the category graph, is still accessible in the test tasks. The first setting (PPN+) is allowed to access this information within test tasks, while the second setting (PPN) is unable to access it. In other words, for PPN+, the edges (i.e., the yellow arrows in Figure 3) between each test class and any other classes are known during the test phase, while these edges are unknown and need to be predicted in the PPN setting. The second setting is more challenging but is preferred in a variety of applications, for example, where the test tasks happen on different devices, whereas the first setting is more appropriate for life-long/continual learning on a single machine.
In the second setting, we can still leverage the prototypes obtained in the training phase and use them for the propagation of prototypes for test classes. In particular, for an unseen test class $y$ (and its associated samples $\mathcal{X}^y$), we find the $k$-nearest neighbors of $P^0_y$ among all the prototypes obtained during training, and treat the training classes associated with these $k$ nearest prototypes as the parents of $y$ on $\mathcal{G}$. These parent-class prototypes provide weakly-supervised information to the test classes via propagation on the category graph.
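For the PPN setting, predicting the parents of an unseen class reduces to a nearest-neighbor search over the buffered training prototypes. A minimal sketch, assuming squared Euclidean distance (the metric is not specified in the text) and toy two-dimensional prototypes; `predict_parents` is a hypothetical name.

```python
import heapq
from typing import Dict, List


def predict_parents(p0_new: List[float], train_prototypes: Dict[str, List[float]],
                    k: int) -> List[str]:
    """Return the k training classes whose prototypes are closest to p0_new."""
    def sq_dist(name: str) -> float:
        return sum((a - b) ** 2 for a, b in zip(p0_new, train_prototypes[name]))
    return heapq.nsmallest(k, train_prototypes, key=sq_dist)


train_prototypes = {"Machine": [0.0, 0.0], "Instrument": [1.0, 0.0], "Animal": [10.0, 10.0]}
parents = predict_parents([0.2, 0.0], train_prototypes, k=2)
print(parents)  # ['Machine', 'Instrument']
```

`heapq.nsmallest` avoids sorting all training prototypes when only a few neighbors are needed.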
In both settings, for each class $y$ in a test task $T$, we apply the prototype propagation in Eq. (3)-(6) on a subgraph composed of $y$ and its parents $\mathcal{P}_y$. This produces the final prototype $P_y$, which is one of the nearest-neighbor candidates for $x$ in KNN classification on task $T$. In the first setting (PPN+), the hierarchy covering both the training and test classes is known, so the parent class of each test class may come from either the training classes or the weakly-labeled test classes. When a parent class $z$ is among the training classes, we use its buffered prototype from training for propagation; otherwise, we use Eq. (3) to compute $P^0_z$ over the weakly-labeled samples belonging to class $z$ in this task. In the second setting (PPN), since the edges connecting test classes are unknown and are predicted by KNN as introduced in the last paragraph, we assume that all the parents of $y$ are from training classes. We directly use the parents' buffered prototypes from training for propagation.
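At test time the two settings differ only in where each parent prototype comes from. The sketch below uses hypothetical names and a caller-supplied attention function standing in for Eq. (5): it reuses a buffered training prototype when the parent is a training class, and otherwise (possible only in PPN+) averages the embeddings of that parent's weakly-labeled samples in the current task, as in Eq. (3). The classes "Vehicle" and the constant attention weight are toy assumptions.

```python
from typing import Callable, Dict, List

Vec = List[float]


def mean(vectors: List[Vec]) -> Vec:
    return [sum(col) / len(vectors) for col in zip(*vectors)]


def parent_prototype(z: str, buffered: Dict[str, Vec], weak_embeddings: Dict[str, List[Vec]]) -> Vec:
    if z in buffered:  # parent is a training class: reuse its buffered prototype
        return buffered[z]
    return mean(weak_embeddings[z])  # PPN+ only: Eq. (3) over weakly-labeled test samples


def build_test_prototype(p0_y: Vec, parent_names: List[str], buffered: Dict[str, Vec],
                         weak_embeddings: Dict[str, List[Vec]],
                         attn: Callable[[Vec, Vec], float], lam: float) -> Vec:
    """Final prototype P_y of a test class via one-step propagation from its parents."""
    if not parent_names:
        return list(p0_y)
    plus = [0.0] * len(p0_y)
    for z in parent_names:
        pz = parent_prototype(z, buffered, weak_embeddings)
        w = attn(p0_y, pz)
        plus = [a + w * b for a, b in zip(plus, pz)]
    return [lam * a + (1.0 - lam) * b for a, b in zip(p0_y, plus)]


def constant_attn(p: Vec, q: Vec) -> float:
    return 0.5  # toy attention weight for illustration


buffered = {"Machine": [1.0, 0.0]}            # prototypes kept from training
weak = {"Vehicle": [[0.0, 2.0], [0.0, 4.0]]}  # weakly-labeled samples in this test task
P_plus = build_test_prototype([2.0, 0.0], ["Machine", "Vehicle"], buffered, weak, constant_attn, lam=0.5)
P_ppn = build_test_prototype([2.0, 0.0], ["Machine"], buffered, {}, constant_attn, lam=0.5)
```

In the PPN setting every parent name comes from the nearest-neighbor prediction over training prototypes, so the `weak_embeddings` branch is never reached.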
|Prototypical Net [Snell et al.2017]|N|33.17 ± 1.65%|46.76 ± 0.98%|20.48 ± 0.99%|31.49 ± 0.57%|
|GNN [Garcia and Bruna2018]|N|30.83 ± 0.66%|41.33 ± 0.62%|20.33 ± 0.60%|22.50 ± 0.62%|
|Closer Look [Chen et al.2019]|N|32.27 ± 1.58%|46.02 ± 0.74%|22.78 ± 0.94%|28.04 ± 0.36%|
|Prototypical Net [Snell et al.2017]|N|31.93 ± 1.62%|49.80 ± 0.90%|21.02 ± 0.97%|36.42 ± 0.62%|
|GNN [Garcia and Bruna2018]|N|33.60 ± 0.11%|45.87 ± 0.12%|22.00 ± 0.89%|34.33 ± 0.75%|
|Closer Look [Chen et al.2019]|N|33.10 ± 1.57%|40.67 ± 0.73%|20.85 ± 0.92%|35.19 ± 0.43%|
We compare PPN/PPN+ to three baseline methods, i.e., Prototypical Networks, GNN and Closer Look [Chen et al.2019], and to their variants that use the same weakly-labeled data as PPN/PPN+. For these variants, we apply the same level-wise training on the same weakly-labeled data as in PPN/PPN+ to the original implementations, i.e., we replace the original training tasks with level-wise training tasks. We always tune the initial learning rate, the learning-rate schedule, and the other hyperparameters of all the baselines and their variants on a validation set of tasks. The results are reported in Table 3 and Table 4, where the variants of baselines are marked by "*" following the baseline name.
In PPN/PPN+ and all the baseline methods (as well as their variants), we use the same backbone CNN (i.e., $f(\cdot;\Theta^{cnn})$) that has been used in most previous few-shot learning works [Snell et al.2017, Finn et al.2017, Vinyals et al.2016]. It has 4 convolutional layers, each with 3×3 convolutional filters, followed by batch normalization [Ioffe and Szegedy2015], ReLU nonlinearity, and 2×2 max-pooling. The transformations $g$ and $h$ in the attention module are fully connected linear layers.
In PPN/PPN+, the variance of $P^0_y$ increases when the number of samples (i.e., the "shot") per class decreases. Hence, we use different values of $\lambda$ in Eq. (6) for tasks with different numbers of shots. During training, we lazily update $P^0_y$ for all the classes on the graph every fixed number of epochs; for PPN, we choose the nearest neighbors among all prototypes obtained after training as the parents of each test class. Adam [Kingma and Ba2015] is used to train the model for 150k iterations with a fixed initial learning rate, weight decay, and momentum, and the learning rate is reduced by a constant factor every 15k iterations, starting from the 10k-th iteration.
WS-ImageNet-Pure is a subset of ILSVRC-12. On the ImageNet WordNet hierarchy, we extract classes from one level as the leaf nodes of the category graph $\mathcal{G}$ and use them as the targeted classes $\mathcal{Y}^{\mathcal{F}}$ in few-shot tasks. The ancestor nodes of these classes on the hierarchy are then sampled from a range of coarser levels, and they compose the weakly-labeled classes $\mathcal{Y}^{\mathcal{W}}$. We sub-sample the data points for classes on $\mathcal{G}$ in a bottom-up manner: for any class $y$ on the bottom level, we directly sample a subset from all the images belonging to $y$ in ImageNet; for any class $y$ on a coarser level, we sample from all the images that belong to $y$ and have not been sampled into its descendant classes. Hence, we know that any data point sampled into class $y$ belongs to all the ancestor classes of $y$, but we do not know its label on any targeted class in $\mathcal{Y}^{\mathcal{F}}$. In addition, we sample each candidate data point for a class with a probability that decreases exponentially with the level of the class, so the number of data points associated with each class reduces exponentially as the class becomes finer. This is consistent with many practical scenarios: samples with finer-class labels provide more information about the targeted few-shot tasks, but they are much more expensive to obtain and usually insufficient.
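The bottom-up sub-sampling can be sketched as follows. This is an illustrative Python version, not the script used to build WS-ImageNet-Pure. Its assumptions are: the keep probability `base ** level` is a hypothetical geometric schedule (the actual probabilities are not given here), levels are numbered from the root (0) down to the leaves, and a single global set of already-assigned images stands in for "not sampled into its descendant classes", which is exact when the hierarchy is a tree.

```python
import random
from typing import Dict, List, Set


def bottom_up_sample(members: Dict[str, Set[int]], level: Dict[str, int],
                     base: float, rng: random.Random) -> Dict[str, List[int]]:
    """Assign images to classes from the finest level upward.

    members[c] holds the ids of all images under class c in the source hierarchy.
    Each candidate on level l is kept with probability base ** l (a hypothetical
    schedule that shrinks exponentially as classes become finer).
    """
    taken: Set[int] = set()  # images already assigned to a finer class
    sampled: Dict[str, List[int]] = {}
    for cls in sorted(members, key=lambda c: level[c], reverse=True):
        keep_prob = base ** level[cls]
        pool = sorted(members[cls] - taken)
        sampled[cls] = [x for x in pool if rng.random() < keep_prob]
        taken.update(sampled[cls])
    return sampled


members = {"Tractor": set(range(10)), "Farm Machine": set(range(30)), "Machine": set(range(60))}
level = {"Machine": 0, "Farm Machine": 1, "Tractor": 2}
full = bottom_up_sample(members, level, base=1.0, rng=random.Random(0))
thinned = bottom_up_sample(members, level, base=0.5, rng=random.Random(0))
```

With `base=1.0` every candidate is kept, which makes the bottom-up exclusion easy to see: "Farm Machine" only receives images that "Tractor" did not take, and "Machine" only the remainder.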
For the training-test split of few-shot classes (each training task is a classification over a randomly sampled subset of training classes, while each test task is a classification over a randomly sampled subset of test classes), we divide the leaf-level classes into two disjoint subsets, one for training and one for test. This ensures that no class in any test task has been learned in training tasks. However, we allow training classes and test classes to share some parent classes. The detailed statistics of WS-ImageNet-Pure are given in Table 2.
|PPN|PPN+|Proto Net|GNN|Closer Look|
The experimental results of PPN/PPN+ and all the baselines (and their weakly-supervised variants, marked with "*") on WS-ImageNet-Pure are shown in Table 3. PPN/PPN+ outperform all other methods. The table shows that PPN/PPN+ are more advantageous in 1-shot tasks, where PPN+ achieves a clear improvement over the other methods. This implies that weakly-supervised information is more helpful when supervised data is highly insufficient, and that our method can significantly boost performance by exploring the weakly-labeled data. Although the weakly-supervised variants of the baselines are trained on the same data as PPN/PPN+, they do not achieve a similar improvement, because their model structures lack a mechanism, such as the prototype propagation in PPN, that relates different classes and tasks. In addition, training on unrelated tasks can be distracting and even detrimental to performance. In contrast, PPN/PPN+ build a computational graph of prototypes associated with both coarse and fine classes, and the error on any class can be used to update the prototypes of other classes via backpropagation on the computational graph.
To verify whether PPN can still learn from weakly-labeled data that belong to other fine classes not involved in the few-shot tasks, we propose another subset of ILSVRC-12, WS-ImageNet-Mix, whose detailed statistics are given in Table 2. We extract WS-ImageNet-Mix following the same procedure as for WS-ImageNet-Pure, except that: 1) data points sampled for a coarse (non-leaf) class can belong to the remaining classes on the leaf level, outside of the classes used for generating few-shot tasks; and 2) for classes on each level, we sample each data point with a different probability than in WS-ImageNet-Pure.
The experimental results are reported in Table 4, which shows that PPN/PPN+ still outperform all baselines, and PPN+ outperforms them in 1-shot classification. This indicates that PPN/PPN+ are robust to (and might be able to leverage) weakly-labeled data unrelated to the final few-shot tasks.
To study whether and how the propagation in Eq. (4) improves few-shot learning, we evaluate PPN using different values of $\lambda$ in Eq. (6). Specifically, we vary the weight of $P^+_y$ in Eq. (6) (x-axis in Figure 4) over a range of values, and report the accuracy (y-axis in Figure 4) on test tasks for $N$-way 1-shot tasks on the two datasets. In all scenarios, increasing the weight of $P^+_y$ within the given range consistently improves the accuracy (although the accuracy might drop if the weight becomes too large), which demonstrates the effectiveness of prototype propagation.
The average per-iteration time (in seconds) on a single TITAN XP for PPN/PPN+ and the baselines is listed in Table 5. PPN has a moderate time cost compared to the other baselines. Compared to the prototypical network, our propagation procedure requires only 10% extra time but significantly improves the performance.
In this work, we propose to explore weakly-labeled data by developing a propagation mechanism that generates class prototypes to improve the performance of few-shot learning. Empirical studies verify the advantage of exploring the weakly-labeled data and the effectiveness of the propagation mechanism. Our model can be extended to multi-step propagation to assimilate more weakly-labeled information. The graph can be more general than a class hierarchy, and the nodes on the graph are not limited to classes: they can be extended to any output attribute.
This research was funded by the Australian Government through the Australian Research Council (ARC) under grants 1) LP160100630, in partnership with the Australian Government Department of Health, and 2) LP150100671, in partnership with the Australian Research Alliance for Children and Youth (ARACY) and Global Business College Australia (GBCA). We also acknowledge the support of NVIDIA Corporation and MakeMagic Australia with the donation of GPUs.
Syntax-directed variational autoencoder for structured data. In ICLR, 2018.