Model-Agnostic Graph Regularization for Few-Shot Learning

02/14/2021 · Ethan Shen, et al. · Google, University of Massachusetts Amherst, Stanford University

In many domains, relationships between categories are encoded in a knowledge graph. Recently, promising results have been achieved by incorporating the knowledge graph as side information in classification tasks with severely limited data. However, prior models consist of highly complex architectures with many sub-components, all of which appear to impact performance. In this paper, we present a comprehensive empirical study of graph-embedded few-shot learning. We introduce a graph regularization approach that allows a deeper understanding of the impact of incorporating graph information between labels. Our proposed regularization is widely applicable and model-agnostic, and boosts the performance of any few-shot learning model, including fine-tuning, metric-based, and optimization-based meta-learning. Our approach improves the performance of strong base learners by up to 2% and 6.7%. Additional analyses reveal that graph-regularized models achieve a lower loss on more difficult tasks, such as those with fewer shots and less informative support examples.


1 Introduction

Few-shot learning refers to the task of generalizing from very few examples, an ability that humans have but machines lack. Recently, major breakthroughs have been achieved with meta-learning, which leverages prior experience from many related tasks to effectively learn to adapt to unseen tasks schmidhuber87, bengio92. At a high level, meta-learning approaches divide into metric-based methods that learn a transferable metric across tasks vinyals2016matching, snell2017prototypical, sung2018learning, and optimization-based methods that learn initializations for fast adaptation on new tasks finn2017model, rusu2018meta. Beyond meta-learning, transfer learning by pre-training and fine-tuning on novel tasks has achieved surprisingly competitive performance on few-shot tasks chen2019closer, dhillon2019baseline, wang2019simpleshot.

In many domains, external knowledge about the class labels is available. For example, this information is crucial in the zero-shot learning paradigm, which seeks to generalize to novel classes without seeing any training examples kankuekul2012online, lampert2013attribute, xian2018zero. Prior knowledge often takes the form of a knowledge graph wang2018zero, such as the WordNet hierarchy miller1995wordnet in computer vision tasks, or the Gene Ontology ashburner2000gene in biology. In such cases, relationships between categories in the graph are used to transfer knowledge from base to novel classes, an idea that dates back to hierarchical classification koller1997hierarchically, salakhutdinov2011learning.

Recently, few-shot learning methods have been enhanced with graph information, achieving state-of-the-art performance on benchmark image classification tasks chen2019knowledge, liu2019prototype, liu2019learning, li2019large, suo2020tadanet. Proposed methods typically employ sophisticated, highly parameterized graph models on top of convolutional feature extractors. However, the complexity of these methods prevents a deeper understanding of the impact of incorporating graph information. Furthermore, these models are inflexible and incompatible with other approaches in the rapidly improving field of meta-learning, demonstrating the need for a model-agnostic graph augmentation method.

Here, we conduct a comprehensive empirical study of incorporating knowledge graph information into few-shot learning. First, we introduce a graph regularization approach for incorporating graph relationships between labels that is applicable to any few-shot learning method. Motivated by node embedding grover2016node2vec and graph regularization principles hallac2015network, our proposed regularization enforces category-level representations to preserve neighborhood similarities in a graph. By design, it allows us to directly measure the benefits of enhancing few-shot learners with graph information. We incorporate our proposed regularization into three major approaches to few-shot learning: (i) metric-based learning, represented by Prototypical Networks snell2017prototypical, (ii) optimization-based learning, represented by LEO rusu2018meta, and (iii) fine-tuning, represented by SGM qiao2018few and the approach of mangla2020charting. We demonstrate that graph regularization consistently improves each method and can be applied whenever category relations are available. Next, we compare our approach to state-of-the-art methods, including those that utilize the same category hierarchy, on the standard Mini-ImageNet benchmark and the large-scale ImageNet-FS dataset. Remarkably, we find that our approach improves the performance of strong base learners by as much as 6.7% and outperforms graph-embedded baselines, even though it is simple, easy to tune, and introduces minimal additional parameters. Finally, we explore the behavior of incorporating graph information in controlled synthetic experiments. Our analysis shows that graph regularization yields better decision boundaries in lower-shot learning and achieves significantly higher gains on more difficult few-shot episodes.

2 Model-Agnostic Graph Regularization

Our approach is a model-agnostic graph regularization objective based on the idea that the graph structure over class labels can guide the learning of model parameters. The graph regularization objective ensures that labels in the same graph neighborhood have similar parameters. The regularization is combined with a classification loss to form the overall objective. The classification loss is flexible and depends on the base learner; for instance, it can be a cross-entropy loss chen2019closer, or a distance-based loss between example embeddings and class prototypes snell2017prototypical.

2.1 Problem Setup

We assume that we are given a dataset $\mathcal{D} = \{(x_i, y_i)\}_{i=1}^{n}$ of examples with corresponding labels. We say that point $x_i$ has the label $y_i$. For each episode, we learn from a support set $S$ and evaluate on a held-out query set $Q$. For each dataset, we split all classes into $\mathcal{C}_{\text{base}}$ and $\mathcal{C}_{\text{novel}}$, with $\mathcal{C}_{\text{base}} \cap \mathcal{C}_{\text{novel}} = \emptyset$. During evaluation, we sample $N$ classes from the larger set of novel classes $\mathcal{C}_{\text{novel}}$, and sample $K$ examples from each class. During training, we use the disjoint set of base classes $\mathcal{C}_{\text{base}}$ to train the model. Non-episodic training approaches treat $\mathcal{C}_{\text{base}}$ as a standard supervised learning problem, while episodic training approaches match the conditions on which the model is trained and evaluated by sampling episodes from $\mathcal{C}_{\text{base}}$. More details on the problem setup can be found in Appendix A. Additionally, we assume that there exists side information about the labels in the form of a graph $G = (V, E)$, where $V$ is the set of all nodes in the label graph and $E$ is the set of edges.

2.2 Regularization

We incorporate graph information using the random walk-based node2vec objective grover2016node2vec. Random walk methods for graph embedding perozzi2014deepwalk are fit by maximizing the probability of predicting the neighborhood of each target node in the graph. Node2vec performs biased random walks, introducing hyperparameters that balance between breadth-first search (BFS) and depth-first search (DFS) to capture local structures and global communities. We formulate the node2vec loss below:

$$\mathcal{L}_{\text{graph}} = -\sum_{u \in V} \sum_{v \in N(u)} \log \frac{\exp\left(\text{sim}(h_u, h_v)/\tau\right)}{Z_u} \qquad (1)$$

where $h_u$ are node representations, $\text{sim}(\cdot,\cdot)$ is a similarity function between the nodes, $N(u)$ is the set of neighbor nodes of node $u$, $\tau$ is the temperature hyperparameter, and $Z_u$ is the partition function defined as $Z_u = \sum_{w \in V} \exp\left(\text{sim}(h_u, h_w)/\tau\right)$. The partition function is approximated using negative sampling mikolov2013distributed. We obtain the neighborhood $N(u)$ by performing a random walk starting from a source node $u$. The similarity function depends on the base learner, as outlined in Section 2.3.
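To make the regularizer concrete, the following is a minimal sketch of Equation (1) with negative sampling, assuming precomputed random-walk neighborhoods; the names (`graph_reg_loss`, `sim_fn`, `neighborhoods`) are illustrative rather than taken from the paper's code:

```python
import torch
import torch.nn.functional as F

def graph_reg_loss(embeddings, neighborhoods, sim_fn, tau=1.0, num_neg=5):
    """Node2vec-style regularizer over label representations.

    embeddings:    (V, d) tensor of node/class representations h_u.
    neighborhoods: dict mapping node index u -> list of neighbor indices N(u)
                   collected from biased random walks.
    sim_fn:        similarity function sim(h_u, h_v), e.g. inner product
                   or negative Euclidean distance.
    """
    V = embeddings.size(0)
    loss, n_terms = 0.0, 0
    for u, neighbors in neighborhoods.items():
        for v in neighbors:
            pos = sim_fn(embeddings[u], embeddings[v]) / tau
            # Negative sampling approximates the partition function Z_u.
            neg_idx = torch.randint(0, V, (num_neg,))
            neg = sim_fn(embeddings[u].unsqueeze(0), embeddings[neg_idx]) / tau
            # word2vec/node2vec-style objective: attract neighbors, repel negatives.
            loss = loss - F.logsigmoid(pos) - F.logsigmoid(-neg).sum()
            n_terms += 1
    return loss / max(n_terms, 1)
```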

2.3 Augmentation Strategies

Our graph regularization framework is model-agnostic and readily applicable to a wide variety of few-shot approaches. Here, we describe augmentation strategies for high-performing learners from metric-based meta-learning, optimization-based meta-learning, and fine-tuning by formulating each as a joint learning objective.

2.3.1 Augmenting Metric-Based Models

Metric-based approaches learn an embedding function to compare query set examples. Prototypical networks are a high-performing learner of this class, especially when controlling for model complexity chen2019closer, triantafillou2019meta. Prototypical networks construct a prototype $c_k$ of each class by taking the mean of its support set examples, and compare query examples to prototypes using Euclidean distance. We regularize these prototypes so they respect class similarities and obtain the joint objective:

$$\min \; \mathcal{L}_{\text{cls}} + \lambda \, \mathcal{L}_{\text{graph}}(c_1, \dots, c_N) \qquad (2)$$

We set the graph similarity function to the negative Euclidean distance, $\text{sim}(h_u, h_v) = -\lVert h_u - h_v \rVert_2$. Note that our approach can easily be extended to other metric-based learners, for example by regularizing the output of the relation module for Relation Networks sung2018learning.
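A hedged sketch of this joint Prototypical Network objective follows, reusing `graph_reg_loss` from the sketch above; using the squared Euclidean distance and remapping query labels to $0..N-1$ are common ProtoNet conventions assumed here, not details quoted from the authors' implementation:

```python
import torch
import torch.nn.functional as F

def neg_sq_euclidean(a, b):
    # ProtoNet conventionally uses the (negative) squared Euclidean distance.
    return -((a - b) ** 2).sum(-1)

def proto_joint_loss(support_emb, support_y, query_emb, query_y,
                     neighborhoods, lam=0.1):
    # Prototype c_k = mean of the support embeddings of class k.
    classes = support_y.unique()
    prototypes = torch.stack([support_emb[support_y == c].mean(0) for c in classes])
    # Distance-based classification loss over the query set;
    # query_y is assumed remapped to 0..N-1 in the order of `classes`.
    logits = neg_sq_euclidean(query_emb.unsqueeze(1), prototypes.unsqueeze(0))
    cls_loss = F.cross_entropy(logits, query_y)
    # Graph term pulls prototypes of related classes together (Eq. 2).
    reg = graph_reg_loss(prototypes, neighborhoods, neg_sq_euclidean)
    return cls_loss + lam * reg
```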

2.3.2 Augmenting Optimization-Based Models

Optimization-based meta-learners such as MAML finn2017model and LEO rusu2018meta consist of two optimization loops: the outer loop updates the neural network parameters to an initialization that enables fast adaptation, while the inner loop performs a few gradient updates over the support set to adapt to the new task. Graph regularization enforces class similarities among parameters during inner-loop adaptation.

Specifically for LEO, we pass support set examples through an encoder to produce latent class encodings $z_n$, which are decoded to generate classifier parameters $\theta_n$. Given instantiated model parameters learned from the outer loop, gradient steps are taken in the latent space to obtain adapted codes $z'_n$ while freezing all other parameters, producing final adapted parameters $\theta'_n$. For more details, please refer to rusu2018meta. Concretely, we obtain the joint regularized objective below for the inner-loop adaptation:

$$\min_{z_1, \dots, z_N} \; \mathcal{L}_{\text{cls}}(\theta'_1, \dots, \theta'_N) + \lambda \, \mathcal{L}_{\text{graph}}(\theta'_1, \dots, \theta'_N) \qquad (3)$$

We set the graph similarity function to the inner product, $\text{sim}(h_u, h_v) = h_u^\top h_v$, though in practice cosine similarity, $\text{sim}(h_u, h_v) = \frac{h_u^\top h_v}{\lVert h_u \rVert \lVert h_v \rVert}$, results in more stable learning.
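As one illustration of the inner-loop objective, here is a simplified sketch in the spirit of LEO: gradient steps in a latent space whose decoded classifier weights incur both the support loss and the graph penalty. The `decoder` module, step count, and step size are assumptions; the real LEO has additional structure (relation networks, stochastic sampling) omitted here:

```python
import torch
import torch.nn.functional as F

def inner_adapt(z, decoder, support_emb, support_y, neighborhoods,
                lam=0.1, steps=5, lr=0.01):
    """Gradient steps in latent space; decoder(z) -> (N, d) classifier weights."""
    dot = lambda a, b: (a * b).sum(-1)
    z = z.clone().requires_grad_(True)
    for _ in range(steps):
        w = decoder(z)
        logits = support_emb @ w.t()
        loss = F.cross_entropy(logits, support_y)
        # Joint inner-loop objective (Eq. 3): support loss + graph penalty.
        loss = loss + lam * graph_reg_loss(w, neighborhoods, dot)
        (grad,) = torch.autograd.grad(loss, z)
        z = (z - lr * grad).detach().requires_grad_(True)
    return decoder(z)  # adapted classifier parameters
```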

2.3.3 Augmenting Fine-tuning Models

Recent approaches such as Baseline++ chen2019closer and mangla2020charting have demonstrated remarkable performance by pre-training a model on the training set and fine-tuning the classifier parameters on the support set of each task. We follow chen2019closer and freeze the feature embedding model during fine-tuning, though the full model can be fine-tuned as well dhillon2019baseline. We perform graph regularization on the classifiers in the last layer of the network, which are learned for novel classes during fine-tuning. This results in the objective below:

$$\min_{w_1, \dots, w_N} \; \mathcal{L}_{\text{cls}}(w_1, \dots, w_N) + \lambda \, \mathcal{L}_{\text{graph}}(w_1, \dots, w_N) \qquad (4)$$

We set the graph similarity to cosine similarity, $\text{sim}(h_u, h_v) = \frac{h_u^\top h_v}{\lVert h_u \rVert \lVert h_v \rVert}$.
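For the fine-tuning case, a minimal sketch with a frozen backbone follows: only the cosine-softmax classifier weights are optimized on the support set, again reusing the `graph_reg_loss` sketch. The regularization scale and step counts below are placeholders, not the paper's tuned values:

```python
import torch
import torch.nn.functional as F

def finetune_classifier(features, labels, W_init, neighborhoods,
                        lam=0.5, steps=100, lr=0.1):
    """Fit cosine-softmax classifier weights W on the support set (Eq. 4)."""
    W = W_init.clone().requires_grad_(True)          # (N, d)
    cos = lambda a, b: F.cosine_similarity(a, b, dim=-1)
    opt = torch.optim.SGD([W], lr=lr)
    for _ in range(steps):
        opt.zero_grad()
        # Cosine-similarity softmax between frozen features and classifiers.
        logits = F.normalize(features, dim=-1) @ F.normalize(W, dim=-1).t()
        loss = F.cross_entropy(logits, labels)
        loss = loss + lam * graph_reg_loss(W, neighborhoods, cos)
        loss.backward()
        opt.step()
    return W.detach()
```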

3 Experimental Results

For all ImageNet experiments, we use the associated WordNet miller1995wordnet category hierarchy to define graph relationships between classes. Details of the experimental setup are given in Appendix B. On a synthetic dataset, we additionally analyze the effect of graph regularization on few-shot methods.

3.1 Mini-ImageNet Experiments

We compare performance to few-shot baselines and the graph-embedded approach KGTN chen2019knowledge on Mini-ImageNet. We enhance the strong fine-tuning baseline of mangla2020charting. Table 1 shows graph regularization results on Mini-ImageNet compared to state-of-the-art models. We find that this baseline, enhanced with the proposed graph regularization, outperforms all other methods on both 1- and 5-shot tasks.

As an additional baseline, we consider KGTN, which also utilizes the WordNet hierarchy for better generalization. To ensure that our improvements are not caused by the embedding function, we pre-train the KGTN feature extractor using the method of mangla2020charting. Even when controlling for improvements in the feature extractor, we find that our simple graph regularization method outperforms complex graph-embedded models.

Model Backbone 1-shot 5-shot
Qiao qiao2018few WRN 28-10 59.60 ± 0.41 73.74 ± 0.19
Baseline++ chen2019closer WRN 28-10 59.62 ± 0.81 78.80 ± 0.61
LEO (train+val) rusu2018meta WRN 28-10 61.76 ± 0.08 77.59 ± 0.12
ProtoNet snell2017prototypical WRN 28-10 62.60 ± 0.20 79.97 ± 0.14
MatchingNet vinyals2016matching WRN 28-10 64.03 ± 0.20 76.32 ± 0.16
mangla2020charting WRN 28-10 64.93 ± 0.18 83.18 ± 0.11
SimpleShot wang2019simpleshot WRN 28-10 65.87 ± 0.20 82.09 ± 0.14
KGTN chen2019knowledge WRN 28-10 65.71 ± 0.75 81.07 ± 0.50
mangla2020charting + Graph (Ours) WRN 28-10 66.93 ± 0.65 83.35 ± 0.53
Table 1: Results on 1-shot and 5-shot classification on the Mini-ImageNet dataset. We report average accuracy (%) over 600 randomly sampled episodes. Graph-based models are shown in the bottom section.

3.2 Graph Regularization is Model-Agnostic

We augment the ProtoNet snell2017prototypical, LEO rusu2018meta, and mangla2020charting approaches with graph regularization and evaluate the effectiveness of our approach on the Mini-ImageNet dataset. These few-shot learning models are fundamentally different and vary in both optimization and training procedures. For example, ProtoNet and LEO are both trained episodically, while the fine-tuning baseline of mangla2020charting is trained non-episodically. However, the flexibility of our graph regularization loss allows us to easily extend each method. Table 2 shows the results of graph-enhanced few-shot baselines. The results demonstrate that graph regularization consistently improves the performance of few-shot baselines, with larger gains in the 1-shot setup.

Model Backbone 1-shot 5-shot
ProtoNet snell2017prototypical ResNet-18 54.16 ± 0.82 73.68 ± 0.65
ProtoNet + Graph (Ours) ResNet-18 55.47 ± 0.73 74.56 ± 0.49
LEO (train) rusu2018meta WRN 28-10 58.22 ± 0.09 74.46 ± 0.19
LEO + Graph (Ours) WRN 28-10 60.93 ± 0.19 76.33 ± 0.17
mangla2020charting WRN 28-10 64.93 ± 0.18 83.18 ± 0.11
mangla2020charting + Graph (Ours) WRN 28-10 66.93 ± 0.65 83.35 ± 0.53
Table 2: Performance of graph-regularized few-shot baselines on the Mini-ImageNet dataset. We report average accuracy (%) over 600 randomly sampled episodes.

3.3 Large-Scale Few-Shot Classification

We next evaluate our graph regularization approach on the large-scale ImageNet-FS dataset, which includes 1000 classes. Notably, this task is more challenging because it requires choosing among all novel classes, an arguably more realistic evaluation procedure. We sample $K$ images per category, repeat the experiments over multiple runs, and report mean accuracy with confidence intervals. Results demonstrate that our graph regularization method boosts the performance of the SGM baseline hariharan2017low by as much as 6.7%. Remarkably, augmenting SGM with graph regularization outperforms all few-shot baselines, as well as models that benefit from class semantic information and the label hierarchy, such as KTCH liu2019large and KGTN chen2019knowledge. We include further experimental details in Appendix B, and explore further ablations to justify design choices in Appendix C.

Model Backbone 1-shot 2-shot 5-shot
SGM hariharan2017low ResNet-50 54.3 67.0 77.4
MatchingNet vinyals2016matching ResNet-50 53.5 63.5 72.7
ProtoNet snell2017prototypical ResNet-50 49.6 64.0 74.4
PMN wang2018low ResNet-50 53.3 65.2 75.9
KTCH liu2019large ResNet-50 58.1 67.3 77.6
KGTN chen2019knowledge ResNet-50 60.1 69.4 78.1
SGM + Graph (Ours) ResNet-50 61.09 ± 0.37 70.35 ± 0.17 78.61 ± 0.19
Table 3: Top-5 accuracy (%) on the novel categories of the ImageNet-FS dataset. KTCH and KGTN are graph-based models. We report confidence intervals where provided. The confidence intervals for hariharan2017low, vinyals2016matching, snell2017prototypical, wang2018low are on the order of .

3.4 Experiments on Synthetic Dataset

To analyze the benefits of graph regularization, we devise a few-shot classification problem on a synthetic dataset. We first embed a balanced binary tree of height $h$ in $d$ dimensions using node2vec grover2016node2vec. We set all leaf nodes as classes, and assign half as base and half as novel. For each task, we sample $K$ support and $Q$ query examples from a Gaussian with mean centered at each class embedding and standard deviation $\sigma$. Given the support examples, the task is to predict the correct class for query examples among the novel classes. In these experiments, we fix the tree height $h$, dimension $d$, shots $K$, queries $Q$, and spread $\sigma$. The baseline model is a linear classifier layer with cross-entropy loss, and we apply graph regularization to this baseline. We learn using SGD with learning rate 0.1 for 100 iterations.
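A sketch of how such tasks could be generated follows, using networkx for the balanced tree and the community `node2vec` package for embeddings; the specific constants below are illustrative placeholders rather than the paper's settings:

```python
import networkx as nx
import numpy as np
from node2vec import Node2Vec  # pip package `node2vec`; a tooling assumption

# Illustrative placeholders; the paper fixes h, d, K, Q, and sigma separately.
height, dim, sigma, n_support, n_query = 6, 2, 0.5, 1, 15

tree = nx.balanced_tree(r=2, h=height)
leaves = [n for n in tree.nodes if tree.degree[n] == 1]

n2v = Node2Vec(tree, dimensions=dim, walk_length=10, num_walks=50)
emb_model = n2v.fit(window=5, min_count=1)
class_emb = {c: emb_model.wv[str(c)] for c in leaves}

rng = np.random.default_rng(0)

def sample_episode(classes):
    """Sample Gaussian support/query examples around each class embedding."""
    support, query = [], []
    for y, c in enumerate(classes):
        mu = class_emb[c]
        support += [(rng.normal(mu, sigma), y) for _ in range(n_support)]
        query += [(rng.normal(mu, sigma), y) for _ in range(n_query)]
    return support, query
```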

We first visualize the learned decision boundaries on identical tasks with and without graph regularization in Figure 1. In this task, the sampled support examples are far away from the query examples, particularly for the purple and green classes. The baseline model learns poor decision boundaries, resulting in many misclassified query examples. In contrast, far fewer query examples are misclassified when graph regularization is applied. Intuitively, graph regularization helps more when the support set lies further from the query examples, and thus generalization is harder.

Figure 1: Synthetic experiment results. PCA visualization of learned classifiers for a single task without (left) and with graph regularization (right). Support examples are squares, query examples are dots, learned classifiers are crosses. Shaded regions show decision boundaries.

To measure the relationship between few-shot task difficulty and performance, we adopt the hardness metric proposed in dhillon2019baseline. Intuitively, few-shot task hardness depends on the relative location of labeled and unlabeled examples. If labeled examples are close to the unlabeled examples of the same class, then the learned classifiers will form good decision boundaries and accuracy will be high. Given a support set $S$ and query set $Q$, the hardness $\Omega$ is defined as the average log-odds of a query example being classified incorrectly:

$$\Omega = \frac{1}{|Q|} \sum_{(x, y) \in Q} \log \frac{1 - p_y(x)}{p_y(x)} \qquad (5)$$

where $p(x)$ is a softmax distribution over the similarity scores between query examples and the means of the support examples of each class in $S$, and $p_y(x)$ is its value for the true class $y$.
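A small sketch of this hardness computation follows; using the negative squared Euclidean distance to class means as the similarity score is an assumption, chosen for consistency with the prototype classifier above:

```python
import numpy as np

def task_hardness(support_x, support_y, query_x, query_y):
    """Average log-odds of misclassifying a query example (Eq. 5)."""
    classes = np.unique(support_y)
    means = np.stack([support_x[support_y == c].mean(0) for c in classes])
    # Similarity scores: negative squared distance to each class mean;
    # query_y is assumed to index into the order of `classes`.
    logits = -((query_x[:, None, :] - means[None, :, :]) ** 2).sum(-1)
    p = np.exp(logits - logits.max(1, keepdims=True))
    p /= p.sum(1, keepdims=True)
    p_correct = p[np.arange(len(query_y)), query_y]
    eps = 1e-12
    return float(np.mean(np.log((1 - p_correct + eps) / (p_correct + eps))))
```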

We show the average loss with shaded 95% confidence intervals across shots in Figure 2 (left), confirming our observation on real-world datasets that graph regularization improves the baseline model the most for tasks with lower shots. Furthermore, using our synthetic dataset, we artificially create more difficult few-shot tasks by increasing the tree height $h$ and increasing $\sigma$, the spread of sampled examples. We plot loss with respect to the proposed hardness metric of each task in Figure 2 (right). The results demonstrate that graph regularization achieves higher performance gains on more difficult tasks.

Figure 2: Quantified results of classification loss across shots (left) and across the task hardness metric (right). Each point is a sampled task. Red denotes the graph-regularized method; gray, the method without graph regularization.

4 Conclusion

We have introduced a graph regularization method for incorporating label-graph side information into few-shot learning. Our approach is simple, effective, and model-agnostic, and boosts the performance of a wide range of few-shot learners. We further showed that the introduced graph regularization outperforms more complex state-of-the-art graph-embedded models.

We thank Yueming Wang and Eli Pugh for discussions and providing feedback on our manuscript. We also gratefully acknowledge the support of DARPA under Nos. FA865018C7880 (ASED), N660011924033 (MCS); ARO under Nos. W911NF-16-1-0342 (MURI), W911NF-16-1-0171 (DURIP); NSF under Nos. OAC-1835598 (CINES), OAC-1934578 (HDR), CCF-1918940 (Expeditions), IIS-2030477 (RAPID); Stanford Data Science Initiative, Wu Tsai Neurosciences Institute, Chan Zuckerberg Biohub, Amazon, Boeing, JPMorgan Chase, Docomo, Hitachi, JD.com, KDDI, NVIDIA, Dell. J. L. is a Chan Zuckerberg Biohub investigator.


Appendix A Problem Statement and Related Work

Episodic Training

A common approach is to learn a few-shot model on $\mathcal{C}_{\text{base}}$ in an episodic manner, so that training and evaluation conditions are matched triantafillou2019meta. Note that training on support set examples during episode evaluation is distinct from training on $\mathcal{C}_{\text{base}}$. Many metric-based and optimization-based meta-learners use this training method, including Matching Networks vinyals2016matching, Prototypical Networks snell2017prototypical, Relation Networks sung2018learning, and MAML finn2017model.

Non-episodic Baselines

Inspired by the transfer learning paradigm of pre-training and fine-tuning, a natural non-episodic approach is to train a classifier on all examples in $\mathcal{C}_{\text{base}}$ at once. After training, the final classification layer is removed, and the neural network is used as an embedding function $f_\theta$ that maps images to feature representations, including those from novel classes. The final classifier layer is then fine-tuned using support set examples from the novel classes. The resulting models are a function of the parameters $W = [w_1, \dots, w_N]$ of a softmax layer. The softmax layer is formulated over the similarity between image feature embeddings and the classifier parameters, where $w_k$ is the parameter vector for the $k$-th class and $\text{sim}$ is the cosine similarity function:

$$p(y = k \mid x) = \frac{\exp\left(\text{sim}(f_\theta(x), w_k)\right)}{\sum_{k'} \exp\left(\text{sim}(f_\theta(x), w_{k'})\right)} \qquad (6)$$

A.1 Related Work

Few-Shot Learning

Canonical approaches to few-shot learning include memory-based gidaris2018dynamic, hariharan2017low, qiao2018few, metric learning ravi2016optimization, vinyals2016matching, snell2017prototypical, sung2018learning, and optimization-based methods finn2017model, rusu2018meta. However, recent studies have shown that simple baseline learning techniques (i.e., simply training a backbone, then fine-tuning the output layer on a few labeled examples) outperform or match performance of many meta-learning methods chen2019closer, dhillon2019baseline, prompting a closer look at the tasks triantafillou2019meta and contexts in which meta-learning is helpful for few-shot learning raghu2019rapid, tian2020rethinking.

Few-Shot Learning with Graphs

Beyond the canonical few-shot literature, studies have explored learning GNNs over episodes as partially observed graphical models garcia2017few and using GCNs to transfer knowledge of semantic labels and categorical relationships to unseen classes in zero-shot learning wang2018zero. Recently, Chen et al. presented a knowledge graph transfer network (KGTN), which uses a Gated Graph Neural Network (GGNN) to propagate information from base categories to novel categories for few-shot learning chen2019knowledge. Other works use domain knowledge graphs to provide task specific customization suo2020tadanet, and propagate prototypes liu2019prototype, liu2019learning. However, these models have highly complex architectures and consist of multiple sub-modules that all seem to impact performance.

Appendix B Experimental Setup

B.1 Mini-ImageNet

Dataset

The Mini-ImageNet dataset is a subset of ILSVRC-2012 deng2009imagenet. Its 100 classes are randomly split into 64, 16, and 20 classes for meta-training, meta-validation, and meta-testing respectively. Each class contains 600 images. We use the commonly-used split proposed in vinyals2016matching.

Training details

We pre-train the feature extractor on $\mathcal{C}_{\text{base}}$ using the method proposed by mangla2020charting. Activations in the penultimate layer are pre-computed and saved as feature embeddings of 640 dimensions to simplify the fine-tuning process. For an $N$-way $K$-shot problem, we sample $N$ novel classes per episode, sample $K$ support examples from those classes, and sample 15 query examples. During the pre-training and meta-training stages, input images are normalized using the mean and standard deviation computed on ILSVRC-2012. We apply standard data augmentation, including random crop, left-right flip, and color jitter, in both the training and meta-training stages. We use ResNet-18, ResNet-50 he2016deep, and WRN-28-10 zagoruyko2016wide as our backbone architectures. For pre-training WRN-28-10, we follow the original hyperparameters and training procedures of mangla2020charting. For meta-training ResNet-18, we follow the hyperparameters from chen2019closer. At evaluation time, we choose hyperparameters based on performance on the meta-validation set. Some implementation details are adjusted for each method; specifically, for ProtoNet and LEO, we include base examples during an additional adaptation step per class. We show that these alterations have a minimal contribution to performance in Appendix C.

B.2 ImageNet-FS

Dataset

In the ImageNet-FS benchmark task, the 1000 ILSVRC-2012 categories are split into 389 base categories and 611 novel categories. From these, 193 of the base categories and 300 of the novel categories are used during cross-validation, and the remaining 196 base categories and 311 novel categories are used for the final evaluation. Each base category has around 1300 training images and 50 test images.

Training details

We follow the procedure of hariharan2017low to pre-train the ResNet-50 feature extractor, and adopt the Square Gradient Magnitude loss, scaled by a constant factor, to regularize representation learning. The model is trained using SGD with momentum and a weight decay of 0.0005; the learning rate is decayed by a constant factor every 30 epochs. During fine-tuning, we train for a fixed number of iterations using SGD with a batch size of 256, momentum, a weight decay of 0.005, and a fixed learning rate.

B.3 Label Graph

WordNet ontology

ImageNet comprises synsets based on the WordNet ontology. For both the Mini-ImageNet and ImageNet-FS experiments, we first choose the synsets corresponding to the output classes of each task: 100 for Mini-ImageNet and 1000 for ImageNet-FS. ImageNet provides IS-A relationships over the synsets, defining a DAG over the classes. We only consider the sub-graph consisting of the chosen classes and their ancestors; the classes are all leaves of the DAG.
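One way to construct such a sub-graph is sketched below with NLTK's WordNet interface and networkx; the mapping from ImageNet wnids to WordNet synsets is assumed to be available, and `build_label_graph` is a hypothetical helper name:

```python
import networkx as nx
from nltk.corpus import wordnet as wn  # requires nltk.download('wordnet')

def build_label_graph(class_synsets):
    """Sub-graph containing the chosen classes and all of their WordNet ancestors."""
    G = nx.DiGraph()
    frontier = list(class_synsets)
    while frontier:
        s = frontier.pop()
        for parent in s.hypernyms():
            if not G.has_edge(parent.name(), s.name()):
                G.add_edge(parent.name(), s.name())  # IS-A edge: parent -> child
                frontier.append(parent)
    return G

# e.g. G = build_label_graph([wn.synset('golden_retriever.n.01'), ...])
```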

Training details

The hyperparameter settings used for the node2vec-based graph regularization objective are in line with the values published in grover2016node2vec. For all experiments, we fix the random-walk hyperparameters and the temperature $\tau$. We set the batch size separately for Mini-ImageNet and ImageNet-FS. Empirically, we find that setting the regularization scale $\lambda$ higher for lower shots results in better performance, and choose separate values of $\lambda$ for 1-, 2-, and 5-shot tasks.

Appendix C Ablations

C.1 Mini-ImageNet Ablations

C.1.1 Model Re-implementations with Adaptation

For episodically-evaluated few-shot models, it is common practice to disregard base classes during evaluation. To implement graph regularization, we include both base and novel classes during test time and perform a further adaptation step per task. We show that the boost in performance is not due to these modifications.

Model Backbone 1-shot 5-shot
ProtoNet ResNet-18 54.16 ± 0.82 73.68 ± 0.65
ProtoNet (adaptation) ResNet-18 54.86 ± 0.73 74.14 ± 0.50
ProtoNet (adaptation) + Graph (Ours) ResNet-18 55.47 ± 0.73 74.56 ± 0.49
LEO WRN 28-10 58.22 ± 0.09 74.46 ± 0.19
LEO (adaptation) WRN 28-10 57.85 ± 0.20 74.25 ± 0.17
LEO (adaptation) + Graph (Ours) WRN 28-10 60.93 ± 0.19 76.33 ± 0.17
Table 4: Validation of baseline model modifications.

C.1.2 Finding Good Parameter Initializations for Novel Classes

Recent works have shown that good parameter initialization is important for few-shot adaptation raghu2019rapid. For example, Dhillon et al. dhillon2019baseline showed that initializing novel classifiers with the mean of the support set improves few-shot performance.

Here, we explore various methods of incorporating graph relations to improve parameter initialization for novel classes. We compare our proposed method with simpler alternatives to show that our graph regularization boosts performance in a non-trivial manner. For each method, we keep the adaptation procedure the same, namely the fine-tuning procedure described by Baseline++ chen2019closer.

We then vary parameter initialization using the following methods: (A) random initialization, (B) initializing novel classes with the weights of the closest training class by graph distance in the knowledge graph, and (C) our method.

Model Backbone 1-shot 5-shot
mangla2020charting + Init A WRN 28-10 64.93 ± 0.18 83.18 ± 0.11
mangla2020charting + Init B WRN 28-10 65.50 ± 0.81 83.32 ± 0.57
mangla2020charting + Init C (Ours) WRN 28-10 66.93 ± 0.65 83.35 ± 0.53
Table 5: Mini-ImageNet accuracy (%) with different parameter initialization methods, measured over 600 evaluation iterations.

C.2 ImageNet-FS Ablations

Here, we justify our model design decisions by considering alternatives. We first probe the benefits of using random walk neighborhoods by restricting $N(u)$ to only nodes that have direct edges with $u$ (“child-parent loss”). We also try separately learning label graph embeddings and passing the information to the classifier layer via a “soft target” classification loss (“Independent graph w/ soft targets”). Results show that computing the graph loss directly on the classifier parameters is important for performance. Finally, we show that the quality of the label graph affects performance by removing layers of internal nodes of the WordNet hierarchy, starting from the bottom-most nodes (“Remove last 5, 10 layers”).

Ablation 1-shot
Ours 61.09
Child-parent loss 56.78
Independent graph w/ soft targets 56.22
Remove last 5 layers 57.80
Remove last 10 layers 54.86
Table 6: ImageNet-FS ablations. Experiment setups, in order from the top: our proposed method, using only child-parent edges, independently learning graph embeddings, removing 5 layers of the ImageNet hierarchy, and removing 10 layers of the ImageNet hierarchy.