Few-shot Classification via Adaptive Attention

08/06/2020, by Zihang Jiang et al., National University of Singapore

Training a neural network model that can quickly adapt to a new task is highly desirable yet challenging for few-shot learning problems. Recent few-shot learning methods mostly concentrate on developing various meta-learning strategies from two aspects, namely optimizing an initial model or learning a distance metric. In this work, we propose a novel few-shot learning method that optimizes and quickly adapts the query sample representation based on very few reference samples. Specifically, we devise a simple and efficient meta-reweighting strategy to adapt the sample representations and generate soft attention to refine the representation, such that the relevant features from the query and support samples can be extracted for better few-shot classification. Such an adaptive attention model is also able to explain, to some extent, what the classification model looks for as evidence for its decisions. As demonstrated experimentally, the proposed model achieves state-of-the-art classification results on various benchmark few-shot classification and fine-grained recognition datasets.


1 Introduction

In recent years, few-shot learning problems gidaris2018dynamic; li2019revisiting; snell2017prototypical; sung2018learning; vinyals2016matching have attracted intensive attention. In this setting, a model must be adapted to tackle unseen classes with only a few training examples. Humans are able to learn new concepts easily from only a few samples, but this is highly difficult for a neural network model, which typically needs a lot of training data to extract meaningful features and otherwise suffers from over-fitting.

To ensure good generalization to novel classes, most few-shot learning methods train the model on a collection of tasks, each of which instantiates a few-shot learning problem and contains both query samples and a few support samples from different classes. Following the episodic learning paradigm thrun2012learning, previous works train the model either by optimizing an initial model to fast adapt to various tasks, e.g. MAML finn2017model, or by learning a good distance metric that clusters samples of the same class together snell2017prototypical; vinyals2016matching; sung2018learning. However, MAML-like optimization-based methods usually suffer from high optimization complexity. On the other hand, though metric-based methods are simpler and generally perform better, they lack flexibility and cannot offer strong adaptation to new tasks. This is mainly because the metric learning procedure does not exploit the relations between support and query samples directly, and the resulting model embeds the query images into the latent space without taking the support samples into account.

When conducting few-shot classification, humans usually take several glimpses at the support and query samples, pay attention to some critical parts and then make the decision dicarlo2012does; logothetis1996visual. Such a procedure is natural and can be leveraged in meta-learning: properly modeling this kind of attention can help quickly localize critical parts and enhance few-shot classification performance. Although several attention mechanisms have been shown to be effective in the fully supervised setting wang2017residual; xiao2015application, existing few-shot classification methods are still unable to accurately obtain an attention map that adapts to the few support samples.

In this work, we develop an efficient meta-reweighting strategy to adapt the representations of query samples by incorporating the representations of support samples. Channel-level reweighting coefficients are first generated from the support features through a small neural network and then applied to query features by a channel-wise multiplication to emphasize important feature maps. The resulting representation therefore contains information from both the query and support samples. Based on the adapted representations after reweighting, we further propose a new support-adaptive attention mechanism: an adaptive attention map is first generated over the query sample from the merged representation by an attention module, indicating the spatial location within the query of the object belonging to the class of the corresponding support samples. The attention map is then used to refine the representation of the query by filtering out information irrelevant to the support samples before classifying the query. Such an adaptive attention mechanism boosts few-shot learning ability, while the whole model remains light-weight and easy to optimize.

To summarize, our contributions are two-fold:

  1. We develop a feature meta-reweighting strategy to extract and exploit the information of support examples, which is different from the traditional metric learning or optimization based meta-learning methods.

  2. We propose an attention mechanism based on the meta-reweighting strategy to localize the region of interest in query samples w.r.t. support samples, which helps refine the classification performance and meanwhile explains the behavior of the model to some extent.

2 Related Work

2.1 Few-Shot Learning

Generally, few-shot learning methods can be divided into meta-learning based methods and metric learning based ones. In addition, some recent works also adopt transfer learning Sun2018MetaTransferLF and data augmentation mehrotra2017generative; NIPS2018_7549 to achieve high accuracy in few-shot classification tasks, but they are based on very deep networks he2016deep; zagoruyko2016wide which involve more intensive pretraining and computation.

Meta-learning based methods

In finn2017model, a model-agnostic meta-learning (MAML) approach was proposed to make the model adapt to novel classes within a few steps of optimization at test time with only a few samples. Instead of directly optimizing the model w.r.t. the target task, MAML-like methods li2017meta; rusu2018meta try to find a good initialization for the parameters of a model, which can then generalize to a new task through a few steps of gradient descent from the initial parameters. Recent works antoniou2018train and rusu2018meta made improvements on MAML to enhance model performance and stability. The work khodadadeh2018unsupervised introduced the idea of MAML to unsupervised meta-learning and also achieved impressive results. Another straightforward way of quickly acquiring new knowledge is to review learned knowledge. In graves2014neural, a stable external memory was used by a neural network to boost performance. RNN memory-based methods like santoro2016meta adopt an LSTM to interact with the external memory. The work ravi2016optimization trained an LSTM-based meta-model to serve as an optimizer for another learner model. Their meta-model also generates a task-common initialization of parameters for the learner model in order to quickly adapt to the test environment.

Metric learning

One simple metric learning based approach is to embed images into a feature space and constrain the feature vectors of the same class to be close in Euclidean or cosine distance. The works vinyals2016matching; snell2017prototypical adopted this idea and further extended it to zero-shot learning and other fields. The work laitask enhanced the Prototypical Network by adding a task-conditioned reweighting module. In sung2018learning, the metric was replaced with a neural network that measures two embedded features and decides the relations between them. This works much better than directly using Euclidean or cosine distance since the embedding space may be very complicated. Deep metric learning is also widely used in representation learning hoffer2015deep, where it demonstrates good performance. Our method can also be seen as learning a deep metric that measures similarity between query and support objects. Note that if we can obtain the location of the object in the image, we are actually comparing two objects instead of two images, which is easier and more effective.

2.2 Attention Models

Recent work gidaris2018dynamic proposed an attention-based classification weight generator which achieved impressive results. Our method, however, is mainly based on spatial attention concerning the precise location of the object in the image. In the past decades, the spatial attention mechanism has been studied by the vision community for both classification and object localization. Learning to localize objects with only image-level labels is very challenging. zhou2016learning proposed to replace some high-level layers of a classification network with a global average pooling (GAP) layer lin2013network followed by a fully connected layer to generate discriminative class activation maps (CAM). zhang2018adversarial improved this by using a fully convolutional network (FCN) and showed that the attention maps can be obtained by directly selecting the feature maps of the last convolution layer. However, none of these methods is applicable to the few-shot setting. We extend this idea to few-shot classification in this work. Our adaptive attention module is somewhat similar to the way object localization maps are obtained in zhang2018adversarial, but that work targets localization tasks.

3 Method

We consider an $N$-way $K$-shot classification problem. In each few-shot classification task $\mathcal{T}$, we are given a set of support samples $\mathcal{S}$ and a batch of query samples $\mathcal{Q}$. The samples in the support set are from $N$ categories, each of which has $K$ labeled samples, i.e., $|\mathcal{S}| = N \times K$. The few-shot classification model is required to acquire knowledge from the support set and classify the query samples accurately.

3.1 Model Workflow

Figure 1: Illustration of our framework. The images first go through a feature extractor. The attention module then integrates the support and query features to get adaptive attention maps. Point-wise multiplication of the attention maps with the query feature gives refined query features for the classifier, which outputs similarity scores between the support and the query.

An overview of our proposed adaptive attention network is illustrated in Fig. 1. It consists of three components: a feature extractor $\mathcal{F}$ similar to most classical few-shot models garcia2017few; gidaris2018dynamic; li2019revisiting; sung2018learning; snell2017prototypical; vinyals2016matching, a classifier $\mathcal{C}$ and an adaptive attention module $\mathcal{A}$. The classifier and attention modules are both light-weight, as explained later.

We first explain the testing phase of the proposed model, taking $N$-way 1-shot classification tasks as an example. Both the query and support samples first go through the feature extractor to produce query and support feature maps respectively. Then the adaptive attention module generates attention maps for the query feature conditioned on the support features. Specifically, for a query feature and one of the support features, the attention module generates an attention map roughly localizing the object belonging to the support class in the query image. This is called an adaptive attention map for the query conditioned on one specific support. If the query image does not contain any object belonging to the class of the support image, the attention map will only highlight some background. These attention maps are then applied to the query feature to generate mask-pooled query features. These attended features are fed into the classifier to give a score indicating the confidence of the query belonging to the class of a specific support, and the query is classified based on the largest confidence score. With such an adaptive attention mechanism, our model can provide better sample representations to ease the downstream classification procedure. A minimal sketch of this testing workflow is shown below; we then explain the details of the attention module in the next subsection.
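The following PyTorch-style sketch illustrates this 1-shot inference procedure under simplified assumptions; the module names (feature_extractor, attention_module, classifier), their call signatures and the tensor shapes are placeholders for exposition rather than the exact implementation.

```python
import torch

def predict_query(feature_extractor, attention_module, classifier,
                  support_images, query_image):
    """Hypothetical N-way 1-shot inference sketch.

    support_images: tensor of shape [N, 3, H, W], one image per class.
    query_image:    tensor of shape [3, H, W].
    Returns the predicted class index in {0, ..., N-1}.
    """
    with torch.no_grad():
        support_feats = feature_extractor(support_images)            # [N, C, h, w]
        query_feat = feature_extractor(query_image.unsqueeze(0))     # [1, C, h, w]

        scores = []
        for s_feat in support_feats:                                  # [C, h, w]
            # Adaptive attention map over the query, conditioned on this support.
            attn = attention_module(s_feat.unsqueeze(0), query_feat)  # [1, 1, h, w]
            refined = attn * query_feat                                # [1, C, h, w]
            scores.append(classifier(refined))                         # confidence score
        scores = torch.stack(scores).view(-1)                          # [N]
    return int(scores.argmax().item())
```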

3.2 Attention Module

Figure 2: Illustration of our attention module. It consists of a meta-weight generator to generate class-specific weights and a spatial attention generator to precisely locate the object in query data points w.r.t. support data points.

Taking as input a query feature and a support feature, the adaptive attention module aims to directly generate spatial attention indicating the location of the object belonging to the support class in the query image. The key question is how to adaptively generate the attention map over the query based on the information from the specific support. In this work, we propose a meta-reweighting based attention generation approach. The support features from the feature extractor are first applied to modulate the query features through channel-wise multiplication. Then the modulated query features, which carry support information, are used to generate the attention maps. In this way, the support information can be efficiently integrated into the attention generation process, and the generated attention map helps select important regions from the query features conditioned on the support. Such an approach is well suited to few-shot scenarios, where the support samples are too scarce to train an attention model through fully supervised learning.

As shown in Fig. 2, the attention module generates the support-adaptive attention through a meta-weight generator $\mathcal{G}_w$ and a spatial attention generator $\mathcal{G}_a$.

Meta-weight generator

Humans recognize an object mainly based on some discriminative features of the support examples in their brain. For example, when trying to find a cat in a room, the focus may be on something hairy or with a tail. Motivated by this, we adopt a meta-reweighting strategy that merges the extracted features with a class-dependent weight vector. In particular, for a support point $x_s$ with label $c$ and a query point $x_q$, denote their extracted features as $f_s = \mathcal{F}(x_s)$ and $f_q = \mathcal{F}(x_q)$. We apply the meta-weight generator $\mathcal{G}_w$ on $f_s$ to get a weight vector $w_c = \mathcal{G}_w(f_s)$ for class $c$. The number of channels of the feature maps in $f_q$ equals the number of entries in $w_c$. Then we can obtain a class-specific feature

$$f_q^c = w_c \odot f_q, \qquad (1)$$

where $\odot$ denotes channel-wise multiplication. Note that sung2018learning also merged the feature maps of the query and support points to get relation scores, by directly concatenating the two feature maps. Differently, our channel-wise multiplication preserves the spatial information of the query feature while emphasizing the feature maps that are crucial for classifying the support class. If $x_q$ does not belong to class $c$, the emphasized feature maps do not contain useful information for class $c$, in which case the classifier predicts a lower score.
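A minimal PyTorch sketch of this reweighting step follows. It assumes a Conv-64F-style feature map with 64 channels and, for simplicity, stands in a global-average-pooling head with two linear layers for the SPP-based meta-weight generator described in Section 4.1; all layer sizes and the final Sigmoid are illustrative assumptions.

```python
import torch
import torch.nn as nn

class MetaWeightGenerator(nn.Module):
    """Sketch: map a support feature map to a per-channel weight vector w_c."""

    def __init__(self, channels=64, hidden=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),       # pool the support feature to [B, C, 1, 1]
            nn.Flatten(),                  # [B, C]
            nn.Linear(channels, hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channels),
            nn.Sigmoid(),                  # keep weights in (0, 1); an assumption
        )

    def forward(self, support_feat):       # support_feat: [B, C, h, w]
        return self.net(support_feat)      # [B, C]

def reweight(query_feat, weights):
    """Eqn. (1): channel-wise multiplication w_c ⊙ f_q.
    query_feat: [B, C, h, w]; weights: [B, C]."""
    return query_feat * weights.unsqueeze(-1).unsqueeze(-1)
```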

Spatial attention generator

Our spatial attention generator consists of an FCN with 2 convolutional layers and a one-channel final output map. This architecture is inspired by zhang2018adversarial, which proposed an efficient way to obtain the region of interest of a deep convolutional network in the weakly supervised setting. The main goal of this generator is to find the important spatial regions in the query image w.r.t. the support sample. Taking as input the class-specific feature of a query, its output map serves as an attention map over the input query feature map for the corresponding support and provides spatial-level attention that helps refine it.

We experimentally justify the advantage of using such a reweighting strategy instead of directly concatenating the feature maps like the Relation Network sung2018learning in Section 4.3. Through visualization, one can see that the output of the FCN in the spatial attention generator indeed serves as an accurate attention map. Given a support feature $f_s$ and a query feature $f_q$, we can get a precise attention map

$$A = \mathcal{G}_a(f_q^c) = \mathcal{G}_a\big(\mathcal{G}_w(f_s) \odot f_q\big), \qquad (2)$$

which indicates the location of the object belonging to the corresponding support class in the query.
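Below is a hedged sketch of such a two-layer FCN attention generator; the kernel sizes and channel widths are assumptions, since the exact widths are given only in Section 4.1.

```python
import torch.nn as nn

class SpatialAttentionGenerator(nn.Module):
    """Sketch: a small FCN producing a one-channel attention map (Eqn. (2))."""

    def __init__(self, in_channels=64, hidden=64):
        super().__init__()
        self.fcn = nn.Sequential(
            nn.Conv2d(in_channels, hidden, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, 1, kernel_size=1),   # one-channel attention map
        )

    def forward(self, class_specific_feat):        # f_q^c: [B, C, h, w]
        return self.fcn(class_specific_feat)       # A: [B, 1, h, w]

# Composition corresponding to Eqn. (2), reusing the sketches above:
# attn = SpatialAttentionGenerator()(reweight(query_feat, MetaWeightGenerator()(support_feat)))
```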

3.3 Classifier

With the adaptive attention module, we can obtain the attention map for each query sample. We then point-wise multiply the attention map with the query feature map. In particular, given an extracted query feature $f_q$, we can get a class-dependent attention map $A$ w.r.t. the support feature $f_s$. The feature fed into the final classifier is then refined by the attention map to

$$\hat{f}_q = A \otimes f_q, \qquad (3)$$

where $\otimes$ denotes point-wise multiplication. In this refined feature map, the region of interest is highlighted by the class-dependent attention map provided by the attention module. The final classifier concentrates more on this region and thus achieves better performance.

More specifically, our classifier consists of a pooling layer followed by several linear layers, taking as input the class-dependent attention maps and the query feature and outputting a single score for each pair of query and attention map. The detailed architecture will be given in the following section. The final output is a single score representing the confidence of the query point belonging to the corresponding support class. We rely on the largest score for the final decision.

Take as input a support and a query feature, denoted as $f_s$ and $f_q$. The final output score is then $s(f_s, f_q) = \mathcal{C}\big(A(f_s, f_q) \otimes f_q\big)$, where $A(f_s, f_q)$ is the adaptive attention map over the query conditioned on the support. Though we can use an activation function to force the output to be positive, the function $s(\cdot, \cdot)$ is still asymmetric. So we add its symmetric counterpart to get

$$S(f_s, f_q) = s(f_s, f_q) + s(f_q, f_s) \qquad (4)$$

as the final output score. The whole framework can then be reinterpreted as deep semimetric learning. Note that the attention map in the original form tries to find the location of the support-class object in the query image, and the symmetric form can also be interpreted as finding the location of the query object in the support image, which helps improve performance when the support image contains not only objects from the support class but also distractors from other classes.
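As a brief sketch of Eqn. (4), the symmetric score can be computed by evaluating the same classifier in both directions; `attention(a, b)` here is an assumed helper returning the attention map over `b` conditioned on `a`, composed of the generators sketched above.

```python
def symmetric_score(classifier, attention, support_feat, query_feat):
    """Eqn. (4): S(f_s, f_q) = s(f_s, f_q) + s(f_q, f_s),
    with s(f_s, f_q) = classifier(A(f_s, f_q) * f_q). A hedged sketch."""
    s_qs = classifier(attention(support_feat, query_feat) * query_feat)    # support -> query
    s_sq = classifier(attention(query_feat, support_feat) * support_feat)  # query -> support
    return s_qs + s_sq
```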

3.4 Training

Following the episodic training scheme vinyals2016matching, we randomly select $N$ classes from the training dataset with $K$ samples each to form the support set for training. A fraction of the remaining data in those classes serves as the query set. The model is then trained on these classification tasks.
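A simple sketch of this episodic task sampling is given below; `dataset_by_class`, which maps each class label to its list of samples, is an assumed data layout rather than part of the original method.

```python
import random

def sample_episode(dataset_by_class, n_way=5, k_shot=1, n_query=15):
    """Sketch: sample an N-way K-shot episode with n_query queries per class."""
    classes = random.sample(list(dataset_by_class.keys()), n_way)
    support, query = [], []
    for episode_label, cls in enumerate(classes):
        samples = random.sample(dataset_by_class[cls], k_shot + n_query)
        support += [(x, episode_label) for x in samples[:k_shot]]
        query += [(x, episode_label) for x in samples[k_shot:]]
    return support, query
```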

We first explain the loss function for training the attention module to produce adaptive attention maps.

Given the support set $\mathcal{S} = \{(x_i, y_i)\}$ with labels $y_i \in \{1, \dots, N\}$ and a query sample $x_q$ with label $y_q$, the meta-weight generator outputs a weight vector $w_i = \mathcal{G}_w(\mathcal{F}(x_i))$ for each support sample $x_i$, as described previously. By averaging the weights within each class, we get class-specific weights $w_c$, $c = 1, \dots, N$. By Eqn. (1), our model generates class-specific features $f_q^c$ using the class-specific weights. Feeding them to the spatial attention generator gives attention maps $A^c$. Applying a global average pooling layer then gives scores $a_c = \mathrm{GAP}(A^c)$ indicating the confidence that the query belongs to class $c$, where $\mathrm{GAP}(A^c)$ denotes the average of the attention map $A^c$. The cross-entropy loss for the attention module is

$$\mathcal{L}_{att} = -\log \frac{\exp(a_{y_q})}{\sum_{c=1}^{N} \exp(a_c)}. \qquad (5)$$

We then explain the loss function for training the classifier and the whole model. Using Eqn. (4), our model outputs scores $S_c$, $c = 1, \dots, N$, each of which indicates the confidence of the query sample being classified into category $c$. The cross-entropy loss is

$$\mathcal{L}_{cls} = -\log \frac{\exp(S_{y_q})}{\sum_{c=1}^{N} \exp(S_c)}. \qquad (6)$$

The final loss for end-to-end training of the whole model is then $\mathcal{L} = \mathcal{L}_{att} + \mathcal{L}_{cls}$.
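A sketch of this two-term objective under the reconstructed equations is shown below; treating the final loss as the unweighted sum of the two cross-entropy terms is our assumption.

```python
import torch.nn.functional as F

def episode_losses(attention_scores, classifier_scores, query_labels):
    """attention_scores: pooled attention maps a_c, shape [num_query, N] (Eqn. (5) logits).
    classifier_scores: symmetric scores S_c, shape [num_query, N] (Eqn. (6) logits).
    query_labels: class indices, shape [num_query]."""
    loss_att = F.cross_entropy(attention_scores, query_labels)   # Eqn. (5)
    loss_cls = F.cross_entropy(classifier_scores, query_labels)  # Eqn. (6)
    return loss_att + loss_cls   # assumed unweighted sum for end-to-end training
```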

4 Experiment

4.1 Setting

We conduct experiments to evaluate the effectiveness of our proposed model on five datasets: Omniglot lake2011one, miniImageNet vinyals2016matching, CUB-200 WelinderEtal2010, Stanford Dogs KhoslaYaoJayadevaprakashFeiFei_FGVC2011 and Stanford Cars KrauseStarkDengFei-Fei_3DRR2013. Details of the datasets and splits are provided in the supplementary material. Among these five datasets, CUB-200, Stanford Dogs and Stanford Cars were originally proposed for fine-grained recognition in the fully supervised setting and have recently been applied to the more challenging few-shot setting li2019revisiting; wei2018piecewise; chen2019closerfewshot, namely fine-grained few-shot classification. Unlike miniImageNet, the variances in these datasets are small, and each class contains only around 100 images or fewer. They are therefore more challenging than the generic datasets Omniglot and miniImageNet, since the model is forced to find more accurate evidence to make a decision.

All experiments are conducted in the 5-way 1-shot or 5-way 5-shot scenario. In testing, we randomly run our model for 600 episodes on each dataset except Omniglot, on which we run 1,000 episodes for a fair comparison. In each episode, we randomly batch 15 query images per class to form a query set of 75 images. The classification accuracy is then calculated by averaging the accuracies over the 600 (1,000 for Omniglot) episodes.
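The evaluation protocol can be sketched as below; `sample_episode_fn` and the model's batched call signature are assumptions for illustration.

```python
import torch

def evaluate(model, sample_episode_fn, n_episodes=600):
    """Sketch: mean accuracy over randomly sampled test episodes."""
    accuracies = []
    model.eval()
    with torch.no_grad():
        for _ in range(n_episodes):
            # One N-way task: support/query images and their episode labels.
            s_imgs, s_labels, q_imgs, q_labels = sample_episode_fn()
            scores = model(s_imgs, s_labels, q_imgs)               # [num_query, N]
            acc = (scores.argmax(dim=1) == q_labels).float().mean()
            accuracies.append(acc.item())
    return 100.0 * sum(accuracies) / len(accuracies)               # accuracy in %
```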

We adopt the commonly used 4-layer convolutional network (Conv-64F) snell2017prototypical; vinyals2016matching, which has 64 channels in each layer, as our feature extractor. We also evaluate our method with the deeper ResNet-256F backbone used in gidaris2018dynamic; li2019revisiting; mishra2017simple. In order to get a more precise attention map, we remove the last max pooling layer. As for the attention module, we use a spatial pyramid pooling (SPP) layer followed by three linear layers for the meta-weight generator, and two convolutional layers followed by a global average pooling layer for the spatial attention generator. The number of additional parameters for the attention module is rather small compared with the Conv-64F feature extractor. The classifier consists of an SPP layer followed by three linear layers.
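As an illustration of such an SPP-based head, the following sketch pools the refined query feature at several pyramid levels and maps it to a single confidence score with three linear layers; the pyramid levels and hidden widths are illustrative assumptions.

```python
import torch
import torch.nn as nn

class SPPClassifier(nn.Module):
    """Sketch: SPP pooling followed by three linear layers producing one score."""

    def __init__(self, channels=64, levels=(1, 2, 4), hidden=128):
        super().__init__()
        self.pools = nn.ModuleList([nn.AdaptiveAvgPool2d(l) for l in levels])
        in_dim = channels * sum(l * l for l in levels)
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, hidden), nn.ReLU(inplace=True),
            nn.Linear(hidden, hidden), nn.ReLU(inplace=True),
            nn.Linear(hidden, 1),
        )

    def forward(self, refined_feat):                               # [B, C, h, w]
        pooled = [p(refined_feat).flatten(1) for p in self.pools]  # multi-scale pooling
        return self.mlp(torch.cat(pooled, dim=1)).squeeze(-1)      # [B] confidence scores
```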

4.2 Comparison with State-of-the-arts

Results on generic datasets

Method                                   Backbone       5-way 1-shot   5-way 5-shot
Matching Network vinyals2016matching     Conv-64F       43.56          55.31
ProtoNet snell2017prototypical           Conv-64F       49.42          68.20
GNN garcia2017few                        Conv-256F      50.33          66.41
Relation Network sung2018learning        Conv-64F       50.44          65.32
DN4 li2019revisiting                     Conv-64F       51.24          71.02
MAML finn2017model                       Conv-32F       48.70          63.11
Dynamic-Net gidaris2018dynamic           Conv-64F       56.20          72.81
Dynamic-Net gidaris2018dynamic           ResNet-256F    55.45          70.13
SNAIL mishra2017simple                   ResNet-256F    55.71          68.88
Ours                                     Conv-64F       56.12          71.48
Ours fine-tune                           Conv-64F       56.33          72.83
Ours                                     ResNet-256F    59.12          72.36
Ours fine-tune                           ResNet-256F    59.26          74.59

Table 1: Few-shot classification accuracy (%) with confidence intervals on miniImageNet, compared with state-of-the-art methods. For the ResNet-256F architecture we refer to mishra2017simple.

Table 1 reports few-shot classification performance on miniImageNet. For each task, we also use the labeled support data to perform task-specific fine-tuning for one iteration and report the results as the "Ours fine-tune" model. Our model outperforms all the state-of-the-art models on miniImageNet for both the Conv-64F and ResNet-256F backbones. Noticeably, our model can point out the spatial location it is looking at, as shown in Figure 3, which largely improves the interpretability of the model's decision making. Our model also achieves performance comparable to the state-of-the-arts on Omniglot. Due to the space limit, we defer the detailed results to the supplementary material.

Figure 3: The adaptive attention map generated w.r.t. the support image is rendered on the query image. The top row shows the query image and the left column shows the support image. (Best viewed in color.)

Results on fine-grained datasets

We also apply our method to the fine-grained classification datasets CUB-200, Stanford Dogs and Stanford Cars, which are more challenging: a classical DNN tends to suffer from severe overfitting on such small datasets. We conduct experiments similar to those on miniImageNet for both 5-way 1-shot and 5-way 5-shot scenarios. As observed from Table 2, for 1-shot learning our method outperforms the state-of-the-arts by a large margin. Such results clearly demonstrate the strong learning ability of our proposed model from very few shots and the benefits of the adaptive attention. In the 5-shot setting, our method performs on par with the latest method li2019revisiting.

Method                                                CUB-200              Stanford Dogs        Stanford Cars
                                                      1-shot    5-shot     1-shot    5-shot     1-shot    5-shot
FSFG wei2018piecewise                                 42.10     62.48      28.78     46.92      29.63     52.28
ProtoNet snell2017prototypical (from li2019revisiting)  37.36     45.28      37.59     48.19      40.90     52.93
GNN garcia2017few (from li2019revisiting)             51.83     63.69      46.98     62.27      55.85     71.25
DN4 li2019revisiting                                  53.15     81.90      45.73     66.33      61.51     89.60
ProtoNet snell2017prototypical (our implementation)   54.54     71.02      50.57     72.60      55.70     68.68
Ours                                                  64.51     78.62      61.74     77.37      70.73     87.72

Table 2: Few-shot classification accuracy (%) with confidence intervals on fine-grained classification datasets, compared with SOTAs. The backbones are all Conv-64F. "Our implementation" marks the ProtoNet results reproduced by us; the other ProtoNet and GNN results are those reported in li2019revisiting.

4.3 Ablation Study

Figure 4: Comparison of the attention maps generated by our meta-reweighting strategy and by direct concatenation. Our attention method localizes the dog more accurately; the map produced by concatenation, however, is quite confusing. (Best viewed in color.)

We perform a set of ablation studies to investigate the effect of each component in our proposed model. The results are summarized in Table 3.

Reweighting vs. concatenating for attention map generation

As shown in Table 3, replacing the meta-reweighting strategy with concatenation leads to no improvement. Note that this variant is essentially a form of the Relation Network sung2018learning. With concatenation, the attention module can already act as a classifier in the one-shot setting, and adding an additional classifier does not improve the model. The generated map cannot indicate the location of the object in the query image, as shown in Figure 4, which confuses the classifier. The reason why directly concatenating query and support features cannot produce a precise attention map is straightforward: concatenation destroys the spatial information of the query image. Instead, our meta-reweighting strategy preserves the spatial information, and the generated attention map precisely locates the dog, as shown in the second row of Figure 4.

Effect of test data augmentation

A key observation is that the weights generated by the meta-weight generator are sometimes unreliable because distractor classes are present in the support images. We therefore apply test data augmentation to enhance the weight generator: we randomly crop and flip the input support images and compute the mean of the weight vectors obtained from these augmented support samples. As shown in Table 3, this stabilizes the weight generation stage and thus helps the classification.
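A possible implementation of this test-time averaging is sketched below; `augmentations`, the number of augmented copies and the batching are assumptions for illustration.

```python
import torch

def augmented_class_weight(weight_generator, feature_extractor,
                           support_image, augmentations, n_aug=8):
    """Sketch: average the weight vectors produced from several randomly
    cropped / flipped versions of a support image."""
    weights = []
    with torch.no_grad():
        for _ in range(n_aug):
            aug = augmentations(support_image).unsqueeze(0)            # [1, 3, H, W]
            weights.append(weight_generator(feature_extractor(aug)))   # [1, C]
    return torch.stack(weights, dim=0).mean(dim=0)                     # averaged weight vector
```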

Method           Combination    TA    AC    5-way 1-shot   5-way 5-shot
Ours             Concatenate                51.31          63.40
Ours             Concatenate                51.79          64.13
Ours             Reweight                   52.81          68.91
Ours             Reweight                   54.98          70.02
Ours             Reweight                   56.12          71.48
Ours fine-tune   Reweight                   56.33          72.83

Table 3: Ablation studies on miniImageNet. The backbones are Conv-64F. TA: test data augmentation; AC: the classifier component.

Necessity of the classifier component

As aforementioned, the attention module can also perform the classification itself. To investigate its performance, we remove the classifier and use only the attention module as a classifier. In this case, the attention module has to generate the spatial attention while performing recognition. From the results, one can observe that the attention map can also be used for classification with satisfactory performance, demonstrating that the adaptive attention maps indeed incorporate discriminative information from the support. However, it does not perform as well as the additional classifier, which focuses specifically on classification with the help of the attention module. This demonstrates the necessity of the additional classifier component.

5 Conclusion

In this paper, we present an efficient framework for few-shot classification. It uses a meta-reweighting strategy together with an attention module to find the location of the query object w.r.t. the support samples, and uses this attention map to adaptively refine the query representation. Experimental results demonstrate the power of the proposed method, especially in the one-shot setting, where it outperforms all state-of-the-art models by a large margin across all the real-world image datasets. Moreover, our method is far simpler than recently proposed meta-learning methods that need much more computation and careful training. The visualization results also show the potential of our approach for helping understand and improve few-shot classification models.

References

6 Datasets

Omniglot lake2011one

It is a hand-written character dataset, and we apply the split and augmentation policy of vinyals2016matching. The original 1,623 classes are augmented with new classes through 90°, 180° and 270° rotations. The 1,200 original classes and their rotated versions are used for training, and the remaining 423 classes plus their rotated images are used for testing. All images are resized to 28 × 28.

miniImageNet vinyals2016matching

It is a mini version of ImageNet imagenet_cvpr09 that serves as a benchmark for few-shot classification. It contains 60,000 color images from 100 classes, with 600 images in each class. Following the splits used in ravi2016optimization, we use 64 classes for training, 16 classes for validation and the remaining 20 classes for testing. All images are resized to 84 × 84.

Cub-200 WelinderEtal2010

The Caltech-UCSD Birds 200 (CUB-200) dataset contains images of 200 bird species. Following the split in li2019revisiting, we use 130 classes for training, 20 classes for validation, and the remaining 50 classes for testing.

Stanford Dogs KhoslaYaoJayadevaprakashFeiFei_FGVC2011

This dataset is also a subset of ImageNet. It contains 20,580 images of 120 breeds (classes) of dogs. Similarly, we use the split in li2019revisiting to get 70, 20 and 30 classes for training, validation and testing respectively.

Stanford Cars KrauseStarkDengFei-Fei_3DRR2013

This dataset contains 16,185 images of 196 classes of cars. Following li2019revisiting, we take 130, 17 and 49 classes for training, validation and testing respectively.

7 Results on Omniglot

Following snell2017prototypical, we train the model in the 60-way 1-shot and 60-way 5-shot settings, and provide results for the 5-way 1-shot and 5-way 5-shot scenarios in Table 4.

Method                                       5-way 1-shot   5-way 5-shot
Matching Network vinyals2016matching         98.1           98.9
Prototypical Network snell2017prototypical   98.8           99.7
GNN garcia2017few                            99.2           99.7
Relation Network sung2018learning            99.6           99.8
MAML finn2017model                           98.4           99.9
Ours                                         99.2           99.7

Table 4: 5-way 1-shot and 5-way 5-shot classification accuracy (%) with confidence intervals on Omniglot, compared with other state-of-the-art methods. The backbone used is Conv-64F.