Improved Few-Shot Visual Classification

12/07/2019 ∙ by Peyman Bateni et al. ∙ The University of British Columbia

Few-shot learning is a fundamental task in computer vision that carries the promise of alleviating the need for exhaustively labeled data. Most few-shot learning approaches to date have focused on progressively more complex neural feature extractors and classifier adaptation strategies, as well as the refinement of the task definition itself. In this paper, we explore the hypothesis that a simple class-covariance-based distance metric, namely the Mahalanobis distance, adopted into a state of the art few-shot learning approach (CNAPS) can, in and of itself, lead to a significant performance improvement. We also discover that it is possible to learn adaptive feature extractors that allow useful estimation of the high dimensional feature covariances required by this metric from surprisingly few samples. The result of our work is a new "Simple CNAPS" architecture which has up to 9.2% fewer trainable parameters than CNAPS and performs up to 6.1% better than state of the art on the standard few-shot image classification benchmark dataset.


1 Introduction

Deep learning successes have led to major computer vision advances [DBLP:journals/corr/abs-1907-09408-object-detection-survey, 8441512-image-classification-survey, Hossain:2019:CSD:3303862.3295748-image-captioning-survey]. However, most of the methods behind these successes operate in fully supervised, high-data-availability regimes. This limits their applicability, effectively excluding domains where data is fundamentally scarce or impossible to label en masse. This has inspired the field of few-shot learning [Wang:2019:SZL:3306498.3293318-survey-of-zero-shot-learning, DBLP:journals/corr/abs-1904-05046-survey-on-few-shot-learning], whose methods aim to computationally mimic human reasoning and learning from limited data.

(a) Squared Euclidean Distance
(b) Mahalanobis Distance
Figure 3: Class-covariance metric: t-SNE visualization of embedded image features (points), class decision boundaries, and test query instances (stars) for a standard squared-Euclidean-based classifier (left) and our proposed class-covariance-based distance metric, the squared Mahalanobis distance (right), computed on top of CNAPS-adapted [requeima2019fast] image features. For ease of comparison, the plot on the right is visualized in a transformed space in which its decision boundaries are linear. The features are obtained from a task taken from the out-of-domain Traffic Signs dataset; the task contains support samples spanning five classes, with one query instance depicted for each class. Illustrating the advantage deriving from our choice of the Mahalanobis distance for classification, the decision boundaries on the right are better aligned with the support embeddings, while query instances of three out of five classes are misclassified by the squared-Euclidean-based classifier (left). Correctly classified queries are marked with green checks, incorrect ones with red crosses.

The goal of few-shot learning is to automatically adapt models such that they work well on instances from classes not seen at training time, given only a few labelled examples for each new class. In this paper, we focus on few-shot image classification where the ultimate aim is to develop a classification methodology that automatically adapts to new classification tasks at test time, and particularly in the case where only a very small number of labelled “support” images are available per class.

Few-shot learning approaches typically take one of two forms: 1) nearest-neighbor approaches and their variants, including matching networks [vinyals2016matching], which effectively apply nearest-neighbor or weighted nearest-neighbor classification on the samples themselves, either in a feature space [DBLP:journals/corr/abs-1905-01436-edge-labelling-gnn, koch2015siamese, garcia2018fewshot] or a semantic space [Frome-NIPS2013_5204]; or 2) embedding methods that effectively distill all of the examples into a single prototype per class, where a prototype may be learned [gidaris2019generating, requeima2019fast] or implicitly derived from the samples [snell2017prototypical] (e.g. a mean embedding). The prototypes are typically defined in a feature or semantic space (e.g. word2vec [DBLP:journals/corr/abs-1902-07104-elementai]). Most research in this domain has focused on learning non-linear mappings, often expressed as neural networks, from images to an embedding space, subject to a pre-defined metric in that embedding space used for the final nearest-class classification (usually cosine similarity between the query image embedding and the class embedding). Most recently, CNAPS [requeima2019fast] achieved state-of-the-art (SoTA) few-shot visual image classification by utilizing sparse FiLM [perez2018film] layers within the context of episodic training to get around the problems that arise from trying to adapt the entire embedding network using only a few support samples.

Overall, much less attention has been given to the metric used to compute distances for classification in the embedding space. Presumably this is because common wisdom dictates that flexible non-linear mappings are ostensibly able to adapt to any such metric, making the choice of metric apparently inconsequential. In practice, as we find in this paper, the choice of metric is quite important. In [snell2017prototypical] the authors analyze the underlying distance function in order to justify the use of sample means as prototypes. They argue that Bregman divergences [banerjee2005clustering] are the theoretically sound family of metrics to use in this setting, but they utilize only a single instance of this family, the squared Euclidean distance, which they find to perform better than the more traditional cosine metric. However, the choice of the Euclidean metric involves two flawed assumptions: 1) that feature dimensions are uncorrelated and 2) that they have uniform variance. It is also insensitive to the distribution of within-class samples with respect to their prototype, and recent results [NIPS2018_7352-tadam, snell2017prototypical] suggest that this is problematic. Modeling this distribution (in the case of [banerjee2005clustering] using extreme value theory) is, as we find, a key to better performance.
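To make the distinction concrete, the following self-contained toy example (not drawn from any CNAPS code base; cluster parameters are made up for illustration) classifies a query point against two correlated, elongated clusters using both metrics. Because the squared Euclidean distance ignores the feature covariance, the two metrics can assign the query to different classes.

```python
import numpy as np

rng = np.random.default_rng(0)

# Two classes sharing an elongated, correlated covariance but with different means.
cov = np.array([[3.0, 2.5],
                [2.5, 3.0]])
mu_a, mu_b = np.array([0.0, 0.0]), np.array([4.0, 0.0])
class_a = rng.multivariate_normal(mu_a, cov, size=50)
class_b = rng.multivariate_normal(mu_b, cov, size=50)

# A query lying along class A's major axis of variation.
query = np.array([2.5, 2.0])

def sq_euclidean(x, mu):
    d = x - mu
    return float(d @ d)

def sq_mahalanobis(x, mu, sigma):
    d = x - mu
    return float(d @ np.linalg.solve(sigma, d))

for name, samples in [("A", class_a), ("B", class_b)]:
    mu, sigma = samples.mean(axis=0), np.cov(samples.T)
    print(f"class {name}: "
          f"sq. Euclidean = {sq_euclidean(query, mu):6.2f}, "
          f"sq. Mahalanobis = {sq_mahalanobis(query, mu, sigma):6.2f}")
# Typically the query is Euclidean-closer to class B but Mahalanobis-closer to class A.
```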

Our Contributions.

  1. A robust empirical finding of a significant 6.1% improvement on average over SoTA (CNAPS [requeima2019fast]) in few-shot image classification, obtained by utilizing a test-time-estimated class-covariance-based distance metric, namely the Mahalanobis distance [galeano2015mahalanobis], in final, task-adapted classification.

  2. The surprising finding that we are able to estimate such a metric even in the few-shot classification setting, where the number of available support examples per class is, in theory, far too small to estimate the required class-specific covariances.

  3. A new “Simple CNAPS” architecture that achieves this performance despite removing 788,485 parameters (3.2%-9.2% of the total) from the original CNAPS architecture, replacing them with fixed, not learned, deterministic covariance estimation and Mahalanobis distance computations.

  4. Evidence that should make readers question the common understanding that CNN feature extractors of sufficient complexity can adapt to whatever the final metric is (be it cosine similarity/dot product or otherwise).

Figure 4: Approaches to few-shot image classification: organized by image feature extractor adaptation scheme (vertical axis) versus final classification method (horizontal axis). Our method (Simple CNAPS) partially adapts the feature extractor (which is architecturally identical to CNAPS) but is trained with, and uses, a fixed (rather than adapted) Mahalanobis metric for final classification.

2 Related Work

Figure 5: Comparison of feature extraction and classification in CNAPS versus Simple CNAPS: Both CNAPS and Simple CNAPS share the feature extractor adaptation architecture detailed in Figure 6. They differ in how distances between query feature vectors and class feature representations are computed for classification. CNAPS uses a trained, adapted linear classifier, whereas Simple CNAPS uses a differentiable but fixed and parameter-free deterministic distance computation. Components shown in light blue have trained parameters, namely the feature extractor adaptation network $\psi_f$ in both models and the classifier adaptation network $\psi_c$ in CNAPS. As shown, CNAPS classification requires adapting roughly 788K additional parameters, while Simple CNAPS classification is fully deterministic.

The last decade’s few-shot learning work [DBLP:journals/corr/abs-1904-05046-survey-on-few-shot-learning] can be differentiated along two main axes: 1) how images are transformed into vectorized embeddings, and 2) how “distances” are computed between vectors in order to assign labels. This is shown in Figure 4.

Siamese networks [koch2015siamese], an early approach to few-shot learning and classification, used a shared feature extractor to produce embeddings for both the support and query images. Classification was then performed by picking the smallest weighted L1 distance between the query and labelled image embeddings. Relation networks [sung2018learning] (and recent GCNN variations [DBLP:journals/corr/abs-1905-01436-edge-labelling-gnn, garcia2018fewshot]) extended this by parameterizing and learning the classification metric using a multi-layer perceptron (MLP). Matching networks [vinyals2016matching] learned distinct feature extractors for support and query images, which were then used to compute cosine similarities for classification.

The feature extractors used by these models were, notably, not adapted to test-time classification tasks. It has since become established that adapting feature extraction to new tasks at test time is generally beneficial. Fine-tuning of transfer-learned networks [DBLP:journals/corr/YosinskiCBL14-finetune] did this by continuing to train the feature extractor on the task-specific support images, but found limited success due to overfitting to the generally very few support examples. MAML [finn2017model] (and its many extensions [DBLP:journals/corr/MishraRCA17-snail, DBLP:journals/corr/abs-1803-02999-reptile, DBLP:conf/iclr/RaviL17-meta-lstm]) mitigated this issue by learning a set of meta-parameters that specifically enable the feature extractor to be adapted to new tasks from few support examples using only a few gradient steps.

The two methods most similar to our own are CNAPS [requeima2019fast] (and the related TADAM [NIPS2018_7352-tadam]) and Prototypical networks [snell2017prototypical]. CNAPS is a few-shot adaptive classifier based on conditional neural processes (CNP) [DBLP:journals/corr/abs-1807-01613-cnp] and is the state-of-the-art approach for few-shot image classification [requeima2019fast]. It uses a pre-trained feature extractor augmented with FiLM layers [perez2018film] that are adapted for each task using the support images specific to that task. CNAPS uses a dot-product distance in a final linear classifier, the parameters of which are also adapted at test time to each new task. We describe CNAPS in greater detail when describing our method.

Prototypical networks [snell2017prototypical] do not use a feature adaptation network; instead, they use a simple mean-pool operation to form class “prototypes.” Euclidean distances to these prototypes are then used for classification. Their choice of distance metric was motivated by the theoretical benefits of Bregman divergences [banerjee2005clustering], a family of functions of which the squared Euclidean distance is a member.

Our work differs from CNAPS [requeima2019fast] and Prototypical networks [snell2017prototypical] in the following ways. First, while CNAPS has demonstrated the importance of adapting the feature extractor to a specific task, we show that adapting the classifier is actually unnecessary to obtain good performance. Second, we demonstrate that an improved choice of Bregman divergence can significantly impact accuracy. Specifically we show that regularized class-specific covariance estimation from task-specific adapted feature vectors allows the use of the Mahalanobis distance for classification, achieving a significant improvement over state of the art. A high-level diagrammatic comparison of our “Simple CNAPS” architecture to CNAPS can be found in Figure 5.

Figure 6: Overview of the feature extractor adaptation methodology in CNAPS: the task encoder $g$ provides the adaptation network $\psi_f^j$ at each block $j$ with the task representation $z_\tau$ to produce the FiLM parameters $\{\gamma_j, \beta_j\}$. For details on the auto-regressive variation (AR-CNAPS), the architectural implementations, and FiLM layers, see Appendix A. For an in-depth explanation, refer to the original paper [requeima2019fast].

3 Formal Problem Definition

We frame few-shot image classification as an amortized classification task. Assume that we have a large labelled dataset $\mathcal{D} = \{(x_i, y_i)\}$ of images $x_i$ and labels $y_i$. From this dataset we can construct a very large number of classification tasks $\mathcal{D}_\tau \subseteq \mathcal{D}$ by repeatedly sampling without replacement from $\mathcal{D}$. Let $\tau$ uniquely identify a classification task. We define the support set of a task to be $\mathcal{S}_\tau = \{(x_i, y_i)\}_{i=1}^{N_\tau}$ and the query set to be $\mathcal{Q}_\tau = \{(x^*_i, y^*_i)\}_{i=1}^{N^*_\tau}$, where $\mathcal{D}_\tau = \mathcal{S}_\tau \cup \mathcal{Q}_\tau$, $x_i, x^*_i \in \mathbb{R}^D$ are vectorized images, and $y_i, y^*_i \in \{1, \dots, K\}$ are class labels. Our objective is to find parameters $\theta$ of a classifier $f_\theta$ that maximize the expected query-set classification probability $\mathbb{E}_\tau\big[\prod_{(x^*_i, y^*_i) \in \mathcal{Q}_\tau} p_\theta(y^*_i \mid x^*_i, \mathcal{S}_\tau)\big]$.

In practice $\mathcal{D}$ is constructed by concatenating large image classification datasets, and the set of classification tasks is sampled in a more complex way than simply without replacement. In particular, constraints are placed on the relationship between the image-label pairs present in the support set and those present in the query set. For instance, in few-shot learning the constraint that the query-set labels are a subset of the support-set labels is imposed. With this constraint imposed, the classification task reduces to correctly assigning each query-set image to one of the classes present in the support set. Also, in this constrained few-shot classification case, the support set can be interpreted as the “training data” for implicitly training (or adapting) a task-specific classifier of query-set images.
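A minimal sketch of this episodic construction is given below. The helper name `sample_task` and the uniform sampling are illustrative simplifications; the actual Meta-Dataset sampler used in our experiments imposes additional constraints on ways, shots, and class balance.

```python
import random
from collections import defaultdict

def sample_task(dataset, n_way=5, n_support=5, n_query=10, seed=None):
    """Draw one few-shot task (support set, query set) from a labelled dataset.

    `dataset` is a list of (image, label) pairs with at least
    n_support + n_query examples per sampled class.
    """
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for x, y in dataset:
        by_class[y].append(x)

    classes = rng.sample(sorted(by_class), n_way)
    support, query = [], []
    for new_label, c in enumerate(classes):               # relabel classes as 0..n_way-1
        examples = rng.sample(by_class[c], n_support + n_query)
        support += [(x, new_label) for x in examples[:n_support]]
        query += [(x, new_label) for x in examples[n_support:]]
    # By construction, query labels are a subset of support labels.
    return support, query
```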

4 Method

Our classifier shares its feature extractor adaptation architecture with CNAPS [requeima2019fast], but deviates from CNAPS by replacing its adaptive classifier with a simpler classification scheme based on estimated Mahalanobis distances. To explain our classifier, which we call “Simple CNAPS,” we first detail CNAPS in Section 4.1 before presenting our classifier in Section 4.2.

4.1 CNAPS

Conditional Neural Adaptive Processes (CNAPS) consist of two elements: a feature extractor and a classifier, both of which are task-adapted. Adaptation is performed by trained adaptation modules that take the support set as input.

The feature extractor architecture used in both CNAPS and Simple CNAPS is shown in Figure 6. It consists of a ResNet18 [DBLP:journals/corr/HeZRS15-resnet] network pre-trained on ImageNet [russakovsky2015imagenet] and augmented with FiLM layers [perez2018film]. The FiLM parameters $\{\gamma_j, \beta_j\}$ can scale and shift the extracted features at each block $j$ of the ResNet18, allowing the feature extractor to focus on, or disregard, different features on a task-by-task basis. A feature extractor adaptation network $\psi_f$ is trained to produce these FiLM parameters based on the support examples provided for the task.

The feature extractor adaptation proceeds in two stages: support set encoding followed by FiLM layer parameter production. The set encoder $g$, parameterized by a deep neural network, produces a permutation-invariant task representation $z_\tau$ based on the support images in $\mathcal{S}_\tau$. This task representation is then passed to $\psi_f^j$, which produces the FiLM parameters $\{\gamma_j, \beta_j\}$ for each block $j$ in the ResNet. Once the FiLM parameters have been set, the feature extractor has been adapted to the task. We use $f_\theta^\tau$ to denote the feature extractor adapted to task $\tau$. The CNAPS paper [requeima2019fast] also proposes an auto-regressive adaptation method which conditions each adaptation module $\psi_f^j$ on the output of the previously adapted blocks. We refer to this variation as AR-CNAPS, but for conciseness we omit the details of this architecture here and instead refer the interested reader to [requeima2019fast] or to Appendix A.1.
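The two-stage adaptation can be sketched as follows. This is a PyTorch-style illustration only: the layer sizes, the set encoder, and the per-block adapter modules are stand-ins, not the exact architectures of Figure 12.

```python
import torch
import torch.nn as nn

class FeatureAdaptationSketch(nn.Module):
    """Sketch of psi_f: encode the support set, then emit FiLM parameters per ResNet block."""

    def __init__(self, task_dim=64, block_channels=(64, 128, 256, 512)):
        super().__init__()
        # Stage 1: set encoder g producing a permutation-invariant task representation z_tau.
        self.set_encoder = nn.Sequential(
            nn.Flatten(), nn.LazyLinear(task_dim), nn.ReLU(), nn.Linear(task_dim, task_dim))
        # Stage 2: one adaptation module psi_f^j per block, emitting (gamma_j, beta_j).
        self.adapters = nn.ModuleList(nn.Linear(task_dim, 2 * c) for c in block_channels)

    def forward(self, support_images):
        z_tau = self.set_encoder(support_images).mean(dim=0)   # mean-pool over the support set
        film_params = []
        for adapter in self.adapters:
            gamma, beta = adapter(z_tau).chunk(2, dim=-1)
            film_params.append((gamma, beta))
        return film_params  # consumed by the FiLM layers inside the pre-trained ResNet18
```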

Classification in CNAPS is performed by a task-adapted linear classifier in which the class probabilities for a query image $x^*_i$ are computed as $\mathrm{softmax}(W f_\theta^\tau(x^*_i) + b)$. The classification weights $W$ and biases $b$ are produced by the classifier adaptation network $\psi_c$, forming $[W, b]$ where, for each class $k$ in the task, the corresponding row of classification weights is produced by $\psi_c$ from the class mean $\mu_k$. The class mean $\mu_k$ is obtained by mean-pooling the feature vectors of the support examples for class $k$ extracted by the adapted feature extractor $f_\theta^\tau$. A visual overview of the CNAPS adapted classifier architecture is shown in Figure 5, bottom left, in red.
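In code, the adapted linear classifier amounts to the following sketch, where `classifier_adapter` stands in for the trained network $\psi_c$ mapping a class mean to a row of weights plus a bias (the layer-level details of $\psi_c$ are in Appendix A.3).

```python
import torch

def cnaps_classify(query_feats, support_feats, support_labels, classifier_adapter):
    """CNAPS-style adapted linear classifier (sketch).

    query_feats:    [M, D] adapted features of the query images
    support_feats:  [N, D] adapted features of the support images
    support_labels: [N]    integer class labels
    classifier_adapter: maps a class mean of shape [D] to weights-and-bias of shape [D + 1]
    """
    weights, biases = [], []
    for k in torch.unique(support_labels):
        mu_k = support_feats[support_labels == k].mean(dim=0)  # class mean by mean-pooling
        w_b = classifier_adapter(mu_k)                         # psi_c(mu_k)
        weights.append(w_b[:-1])
        biases.append(w_b[-1])
    W, b = torch.stack(weights), torch.stack(biases)           # [K, D], [K]
    logits = query_feats @ W.T + b                             # adapted dot-product classifier
    return torch.softmax(logits, dim=-1)                       # [M, K] class probabilities
```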

(a) Euclidean Norm
(b) Mahalanobis Distance
Figure 9: Problematic nature of the unit-normal assumption: The Euclidean Norm (left) assumes embedded image features are distributed around class means according to a unit normal. The Mahalanobis distance (right) considers cluster variance when forming decision boundaries, indicated by the background colour.

4.2 Simple CNAPS

In Simple CNAPS, we also use the same pre-trained ResNet18 for feature extraction with the same adaptation network $\psi_f$, although, because of the classifier architecture we use, $\psi_f$ becomes trained to do something different than it does in CNAPS. This choice, as for CNAPS, allows for a task-specific adaptation of the feature extractor. Unlike CNAPS, we directly compute

$$p(y^*_i = k \mid f_\theta^\tau(x^*_i), \mathcal{S}_\tau) = \mathrm{softmax}\big(-d_k(f_\theta^\tau(x^*_i), \mu_k)\big) \qquad (1)$$

using a deterministic, fixed

$$d_k(x, y) = \tfrac{1}{2}(x - y)^\top (Q_k^\tau)^{-1} (x - y). \qquad (2)$$

Here $Q_k^\tau$ is a covariance matrix specific to the task $\tau$ and class $k$. As we cannot know the value of $Q_k^\tau$ ahead of time, it must be estimated from the feature embeddings of the task-specific support set. As the number of examples in any particular support set is likely to be much smaller than the dimension of the feature space, we use a regularized estimator

$$Q_k^\tau = \lambda_k^\tau \Sigma_k^\tau + (1 - \lambda_k^\tau)\,\Sigma^\tau + \beta I \qquad (3)$$

formed from a convex combination of the class-within-task and all-classes-in-task covariance matrices $\Sigma_k^\tau$ and $\Sigma^\tau$ respectively, together with a small regularizer $\beta I$.

We estimate the class-within-task covariance matrix $\Sigma_k^\tau$ as the sample covariance of the feature embeddings $f_\theta^\tau(x_i)$ of all $x_i \in \mathcal{S}_k^\tau$, where $\mathcal{S}_k^\tau$ is the set of examples in $\mathcal{S}_\tau$ with class label $k$. If the number of support instances of that class is one, i.e. $|\mathcal{S}_k^\tau| = 1$, then we define $\Sigma_k^\tau$ to be the zero matrix of the appropriate size. The all-classes-in-task covariance $\Sigma^\tau$ is estimated in the same way as the class-within-task covariance, except that it uses all the support set examples regardless of their class.

We choose a particular, deterministic scheme for computing the weighting of the class- and task-specific covariance estimates, $\lambda_k^\tau = N_k^\tau / (N_k^\tau + 1)$, where $N_k^\tau = |\mathcal{S}_k^\tau|$. This choice means that in the case of a single labelled instance for class $k$ in the support set, a single “shot,” $\lambda_k^\tau = 1/2$. This can be viewed as increasing the strength of the regularization parameter $\beta I$ relative to the task covariance $\Sigma^\tau$. When $N_k^\tau = 2$, $\lambda_k^\tau$ becomes $2/3$ and only partially favors the class-level covariance over the all-class-level covariance. In a high-shot setting, $\lambda_k^\tau$ tends to 1 and $Q_k^\tau$ consists mainly of the class-level covariance. The intuition behind this formula is that the higher the number of shots, the better the class-within-task covariance estimate gets, and the more $Q_k^\tau$ should look like $\Sigma_k^\tau$. We considered other ratios, and making the $\lambda_k^\tau$ learnable parameters, but found that out of all the considered alternatives this simple deterministic ratio produced the best results. The architecture of the classifier in Simple CNAPS appears in Figure 5, bottom right, in blue.
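Putting Equations 1-3 together, the entire Simple CNAPS classification head reduces to a short, parameter-free computation on top of the adapted feature vectors. The sketch below illustrates it in PyTorch; the value of the ridge term `beta` is a placeholder, not necessarily the setting used in our experiments.

```python
import torch

def simple_cnaps_classify(query_feats, support_feats, support_labels, beta=1.0):
    """Simple CNAPS head: softmax over negative squared Mahalanobis distances (sketch)."""
    n, dim = support_feats.shape

    # All-classes-in-task covariance (the task-level regularization target in Equation 3).
    centered = support_feats - support_feats.mean(dim=0)
    sigma_task = centered.T @ centered / max(n - 1, 1)

    neg_dists = []
    for k in torch.unique(support_labels):
        feats_k = support_feats[support_labels == k]
        n_k = feats_k.shape[0]
        mu_k = feats_k.mean(dim=0)
        if n_k > 1:
            ck = feats_k - mu_k
            sigma_k = ck.T @ ck / (n_k - 1)                     # class-within-task covariance
        else:
            sigma_k = torch.zeros(dim, dim, dtype=support_feats.dtype)  # single-shot case
        lam = n_k / (n_k + 1.0)                                 # deterministic blending weight
        q_k = (lam * sigma_k + (1 - lam) * sigma_task
               + beta * torch.eye(dim, dtype=support_feats.dtype))
        diff = query_feats - mu_k                               # [M, D]
        maha = 0.5 * (diff * torch.linalg.solve(q_k, diff.T).T).sum(dim=-1)
        neg_dists.append(-maha)
    return torch.softmax(torch.stack(neg_dists, dim=-1), dim=-1)  # [M, K] class probabilities
```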

5 Theory

The class-label probability calculation appearing in Equation 1 corresponds to an equally weighted exponential-family mixture model [snell2017prototypical], where the exponential-family distribution is uniquely determined by a regular Bregman divergence [banerjee2005clustering]

$$D_F(z, z') = F(z) - F(z') - (z - z')^\top \nabla F(z') \qquad (4)$$

for a differentiable and strictly convex function $F$. The squared Mahalanobis distance in Equation 2 is the Bregman divergence generated by the convex function $F(z) = \tfrac{1}{2} z^\top (Q_k^\tau)^{-1} z$ and corresponds to the multivariate normal exponential-family distribution. This allows us to view the class probabilities of Equation 1 as the “responsibilities” in a Gaussian mixture model

$$p(y^*_i = k \mid f_\theta^\tau(x^*_i), \mathcal{S}_\tau) = \frac{\pi_k\, \mathcal{N}\!\big(f_\theta^\tau(x^*_i); \mu_k, Q_k^\tau\big)}{\sum_{k'} \pi_{k'}\, \mathcal{N}\!\big(f_\theta^\tau(x^*_i); \mu_{k'}, Q_{k'}^\tau\big)} \qquad (5)$$

with equally weighted mixing coefficients $\pi_k = 1/K$.

This perspective immediately highlights a problem with the squared Euclidean norm, used by a number of the approaches shown in Figure 4. The Euclidean norm, which corresponds to the squared Mahalanobis distance with $Q_k^\tau = I$, implicitly assumes that each cluster is distributed according to a unit normal, as seen in Figure 9. By contrast, the squared Mahalanobis distance considers the cluster covariance when computing distances to the cluster centers.
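The correspondence is easy to verify numerically. The check below uses one covariance shared across all classes so that the Gaussian normalizing constants cancel exactly; it is a verification sketch, not part of the model.

```python
import numpy as np
from scipy.special import softmax
from scipy.stats import multivariate_normal

rng = np.random.default_rng(1)
dim, n_classes = 4, 3
Q = np.eye(dim) + 0.3                         # shared, non-diagonal covariance
mus = rng.normal(size=(n_classes, dim))       # class means
x = rng.normal(size=dim)                      # an embedded query point

# Equation 1 with the squared Mahalanobis distance of Equation 2.
Q_inv = np.linalg.inv(Q)
d = np.array([0.5 * (x - mu) @ Q_inv @ (x - mu) for mu in mus])
p_softmax = softmax(-d)

# Responsibilities of the equally weighted Gaussian mixture of Equation 5.
dens = np.array([multivariate_normal.pdf(x, mean=mu, cov=Q) for mu in mus])
p_gmm = dens / dens.sum()

print(np.allclose(p_softmax, p_gmm))          # True: the class probabilities coincide
```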

6 Experiments

We evaluate Simple CNAPS on the Meta-Dataset [triantafillou2019meta] family of datasets, demonstrating improvements compared to nine baseline methodologies including the current SoTA, CNAPS. Benchmark results reported come from [triantafillou2019meta, requeima2019fast].

Dataset

Meta-Dataset [triantafillou2019meta] is a benchmark for few-shot learning and image classification. It is the union of ten labeled image datasets: ILSVRC-2012 (ImageNet) [russakovsky2015imagenet], Omniglot [lake2015human], FGVC-Aircraft (Aircraft) [maji2013fine], CUB-200-2011 (Birds) [wah2011caltech], Describable Textures (DTD) [cimpoi2014describing], QuickDraw [jongejan2016quick], FGVCx Fungi (Fungi) [fungi2018schroeder], VGG Flower (Flower) [nilsback2008automated], Traffic Signs (Signs) [houben2013detection] and MSCOCO [lin2014microsoft]. In keeping with prior art, we report results using the first eight datasets for training, reserving Traffic Signs and MSCOCO for “out-of-domain” performance evaluation. Additionally, some classes from the eight training datasets are held out for testing, to evaluate “in-domain” performance. Following [requeima2019fast], we extend the out-of-domain evaluation with three more datasets: MNIST [lecun-mnisthandwrittendigit-2010], CIFAR10 [Krizhevsky09learningmultiple] and CIFAR100 [Krizhevsky09learningmultiple]. We report results using the standard test/train splits provided by [triantafillou2019meta], but, importantly, we have cross-validated our critical empirical claims using different test/train splits, and our results are robust across folds (see Appendix B).

6.1 Results

In-Domain Accuracy
Model ImageNet Omniglot Aircraft Birds DTD QuickDraw Fungi Flower
MAML 32.4±1.0 71.9±1.2 52.8±0.9 47.2±1.1 56.7±0.7 50.5±1.2 21.0±1.0 70.9±1.0
RelationNet 30.9±0.9 86.6±0.8 69.7±0.8 54.1±1.0 56.6±0.7 61.8±1.0 32.6±1.1 76.1±0.8
k-NN 38.6±0.9 74.6±1.1 65.0±0.8 66.4±0.9 63.6±0.8 44.9±1.1 37.1±1.1 83.5±0.6
MatchingNet 36.1±1.0 78.3±1.0 69.2±1.0 56.4±1.0 61.8±0.7 60.8±1.0 33.7±1.0 81.9±0.7
Finetune 43.1±1.1 71.1±1.4 72.0±1.1 59.8±1.2 69.1±0.9 47.1±1.2 38.2±1.0 85.3±0.7
ProtoNet 44.5±1.1 79.6±1.1 71.1±0.9 67.0±1.0 65.2±0.8 64.9±0.9 40.3±1.1 86.9±0.7
ProtoMAML 47.9±1.1 82.9±0.9 74.2±0.8 70.0±1.0 67.9±0.8 66.6±0.9 42.0±1.1 88.5±0.7
CNAPS 51.3±1.0 88.0±0.7 76.8±0.8 71.4±0.9 62.5±0.7 71.9±0.8 46.0±1.1 89.2±0.5
AR-CNAPS 52.3±1.0 88.4±0.7 80.5±0.6 72.2±0.9 58.3±0.7 72.5±0.8 47.4±1.0 86.0±0.5
Simple AR-CNAPS 56.5±1.1 91.1±0.6 81.8±0.8 74.3±0.9 72.8±0.7 75.2±0.8 45.6±1.0 90.3±0.5
Simple CNAPS 58.6±1.1 91.7±0.6 82.4±0.7 74.9±0.8 67.8±0.8 77.7±0.7 46.9±1.0 90.7±0.5
Table 1: In-domain few-shot classification accuracy of Simple CNAPS and Simple AR-CNAPS compared to the baselines. All values are percentages and the error bars represent 95% confidence intervals over tasks. The best performance(s) on each dataset are shown in bold, while underlined values are better than all baselines (excluding the non-corresponding CNAPS baseline) by statistically significant margins.
Out-of-Domain Accuracy Average Accuracy
Model Signs MSCOCO MNIST CIFAR10 CIFAR100 In-Domain Out-Domain Overall
MAML 34.2±1.3 24.1±1.1 NA NA NA 50.4±1.0 29.2±1.2 46.2±1.1
RelationNet 37.5±0.9 27.4±0.9 NA NA NA 58.6±0.9 32.5±0.9 53.3±0.9
k-NN 40.1±1.1 29.6±1.0 NA NA NA 59.2±0.9 34.9±1.1 54.3±0.9
MatchingNet 55.6±1.1 28.8±1.0 NA NA NA 59.8±0.9 42.2±1.1 56.3±1.0
Finetune 66.7±1.2 35.2±1.1 NA NA NA 60.7±1.1 51.0±1.2 58.8±1.1
ProtoNet 46.5±1.0 39.9±1.1 74.3±0.8 66.4±0.7 54.7±1.1 64.9±1.0 56.4±0.9 61.6±0.9
ProtoMAML 52.3±1.1 41.3±1.0 NA NA NA 67.5±0.9 46.8±1.1 63.4±0.9
CNAPS 60.1±0.9 42.3±1.0 88.6±0.5 60.0±0.8 48.1±1.0 69.6±0.8 59.8±0.8 65.9±0.8
AR-CNAPS 60.2±0.9 42.9±1.1 92.7±0.4 61.5±0.7 50.1±1.0 69.7±0.8 61.5±0.8 66.5±0.8
Simple AR-CNAPS 74.7±0.7 44.3±1.1 95.7±0.3 69.9±0.8 53.6±1.0 73.5±0.8 67.6±0.8 71.2±0.8
Simple CNAPS 73.5±0.7 46.2±1.1 93.9±0.4 74.3±0.7 60.5±1.0 73.8±0.8 69.7±0.8 72.2±0.8
Table 2: Middle) Out-of-domain few-shot classification accuracy of Simple CNAPS and Simple AR-CNAPS compared to the baselines. Right) In-domain, out-of-domain, and overall mean classification accuracy of Simple CNAPS and Simple AR-CNAPS compared to the baselines. All values are percentages and the error bars represent 95% confidence intervals over tasks. The best performance(s) on each dataset are shown in bold, while underlined values are better than the baselines (excluding the non-corresponding CNAPS baseline) by statistically significant margins.

In-domain performance

The in-domain results for Simple CNAPS and Simple AR-CNAPS are shown in Table 1. Simple AR-CNAPS has the same Mahalanobis-based classifier as Simple CNAPS but uses the auto-regressive feature extractor adaptation of CNAPS (AR-CNAPS). Simple AR-CNAPS outperforms previous SoTA models on 7 out of the 8 datasets while matching past SoTA on FGVCx Fungi (Fungi). Simple CNAPS outperforms the baselines on 6 out of 8 datasets while matching their performance on FGVCx Fungi (Fungi) and Describable Textures (DTD). Overall, the in-domain performance gains are considerable for few-shot learning, with margins of 2-6%. Simple CNAPS achieves an average 73.8% accuracy on in-domain few-shot classification, a 4.2% gain over CNAPS, while Simple AR-CNAPS achieves 73.5% accuracy, a 3.8% gain over AR-CNAPS.

Out-of-domain performance

As shown in Table 2, Simple CNAPS and Simple AR-CNAPS produce substantial gains in classification accuracy on the out-of-domain datasets, each exceeding the SoTA baselines. With average out-of-domain accuracies of 69.7% and 67.6%, Simple CNAPS and Simple AR-CNAPS outperform SoTA by 8.2% and 7.8% respectively. This means that Simple CNAPS/AR-CNAPS generalize to out-of-domain datasets better than the baseline models. Note that Simple AR-CNAPS under-performs Simple CNAPS, suggesting that the auto-regressive feature adaptation approach may overfit to the domain of the datasets it has been trained on.

Overall performance

Overall, Simple CNAPS achieves the best classification accuracy at 72.2% with Simple AR-CNAPS trailing very closely at 71.2%. As the overall performance of the two variants are statistically indistinguishable, we recommend Simple CNAPS over Simple AR-CNAPS as it has fewer parameters.

Figure 10: Accuracy vs. shots: average number of support examples per class (log scale) vs. classification accuracy. The accuracies of all classes within test-set tasks (both in-domain and out-of-domain) are grouped by the number of support examples per class, and each group is averaged to obtain a value for each shot count. The squared Euclidean variant is referred to as L2.
Figure 11: Accuracy vs. ways: number of ways (classes in the task) vs. classification accuracy. Test-set tasks are grouped by the number of classes, and the accuracies are averaged to obtain a value for each class count. Note that the squared Euclidean variant is referred to as L2.
In-Domain Accuracy
Metric ImageNet Omniglot Aircraft Birds DTD QuickDraw Fungi Flower
Negative Dot Product 48.0±1.1 83.5±0.9 73.7±0.8 69.0±1.0 66.3±0.6 66.5±0.9 39.7±1.1 88.6±0.5
Absolute Distance (L1) 53.6±1.1 90.6±0.6 81.0±0.7 73.2±0.9 61.1±0.7 74.1±0.8 47.0±1.0 87.3±0.6
Squared Euclidean (L2) 53.9±1.1 90.9±0.6 81.8±0.7 73.1±0.9 64.4±0.7 74.9±0.8 45.8±1.0 88.8±0.5
Simple CNAPS -TR 56.7±1.1 91.1±0.7 83.0±0.7 74.6±0.9 70.2±0.8 76.3±0.9 46.4±1.0 90.0±0.6
Simple CNAPS 58.6±1.1 91.7±0.6 82.4±0.7 74.9±0.8 67.8±0.8 77.7±0.7 46.9±1.0 90.7±0.5
Table 3: In-domain few-shot classification accuracy of Simple CNAPS compared to ablated alternatives: the negative dot product, absolute difference (L1), squared Euclidean (L2), and removal of the task regularization term $\Sigma^\tau$, denoted by “-TR”. All values are percentages and the error bars represent 95% confidence intervals over tasks. The best performance(s) on each dataset are shown in bold, while underlined values are better than all ablations by statistically significant margins.
Out-of-Domain Accuracy Average Accuracy
Metric Signs MSCOCO MNIST CIFAR10 CIFAR100 In-Domain Out-Domain Overall
Negative Dot Product 53.9±0.9 32.5±1.0 86.4±0.6 57.9±0.8 38.8±0.9 66.9±0.9 53.9±0.8 61.9±0.9
Absolute Distance (L1) 66.4±0.8 44.7±1.0 88.0±0.5 70.0±0.8 57.9±1.0 71.0±0.8 65.4±0.8 68.8±0.8
Squared Euclidean (L2) 68.5±0.7 43.4±1.0 91.6±0.5 70.5±0.7 57.3±1.0 71.7±0.8 66.3±0.8 69.6±0.8
Simple CNAPS -TR 74.1±0.6 46.9±1.1 94.8±0.4 73.0±0.8 59.2±1.0 73.5±0.8 69.6±0.8 72.0±0.8
Simple CNAPS 73.5±0.7 46.2±1.1 93.9±0.4 74.3±0.7 60.5±1.0 73.8±0.8 69.7±0.8 72.2±0.8
Table 4: Middle) Out-of-domain few-shot classification accuracy of Simple CNAPS compared to ablated alternatives: the negative dot product, absolute difference (L1), squared Euclidean (L2), and removal of the task regularization term $\Sigma^\tau$, denoted by “-TR”. Right) In-domain, out-of-domain, and overall mean classification accuracies of the ablated models. All values are percentages and the error bars represent 95% confidence intervals over tasks. The best performance(s) on each dataset are shown in bold, while underlined values are better than all ablations by statistically significant margins.

Comparison with other distance metrics

To test the significance of our choice of the Mahalanobis distance, we substitute it within our architecture with other distance metrics: absolute difference (L1), squared Euclidean (L2), and the negative dot product. Performance comparisons for the in-domain datasets are shown in Table 3, and for the out-of-domain datasets and overall performance in Table 4. We observe that using the Mahalanobis distance results in the best in-domain, out-of-domain, and overall average performance.

Impact of the task regularizer

We also consider a variant of Simple CNAPS in which the all-classes-in-task covariance matrix $\Sigma^\tau$ is not included in the covariance regularization of Equation 3 (denoted with the “-TR” tag). As shown in Table 4, we observe that, while removing the task-level regularizer only marginally reduces overall performance, the difference on individual datasets such as ImageNet can be large.

Sensitivity to the number of support examples per class

Figure 10 shows how the overall classification accuracy varies as a function of the average number of support examples per class (shots) over all tasks. We compare Simple CNAPS, original CNAPS, and the squared Euclidean variant of our method. As expected, the average number of support examples per class is highly correlated with the performance. All methods perform better with more labeled examples per support class, with Simple CNAPS performing substantially better as the number of shots increases. The surprising discovery is that Simple CNAPS is effective even when the number of labeled instances is as low as four, suggesting both that even poor estimates of the task and class specific covariance matrices are helpful and that the regularization scheme we have introduced works remarkably well.

Sensitivity to the number of classes in the task

In Figure 11, we examine average accuracy as a function of the number of classes in the task. We find that, irrespective of the number of classes in the task, we maintain accuracy improvement over both CNAPS and our squared Euclidean norm variant.

7 Discussion

Few-shot learning is a fundamental task in modern AI research. In this paper we have introduced a new method for amortized few-shot image classification which establishes a new SoTA performance benchmark by simplifying the current SoTA architecture. Our specific architectural choice, that of deterministically estimating and using Mahalanobis distances for classification of task-adjusted, class-specific feature vectors, seems to produce, via training, embeddings that generally allow for useful covariance estimates, even when the number of labeled instances per task and class is small. The effectiveness of the Mahalanobis distance in feature space for distinguishing classes suggests connections to hierarchical regularization schemes [pmlr-v27-salakhutdinov12a] that could enable performance improvements even in the zero-shot setting. In the future, exploration of other Bregman divergences is an avenue of potentially fruitful research, while further work on metric learning from optimization and probabilistic standpoints may yield interesting results. Additional enhancements in the form of data and task augmentation could also boost performance.

References

Appendix A (Simple) CNAPS in Detail

A.1 Auto-Regressive CNAPS

Figure 12: Architectural overview of the feature extractor adaptation network $\psi_f$: the figure is adapted from [requeima2019fast] and shows the neural architecture used for each adaptation module $\psi_f^j$ (corresponding to residual block $j$) in the feature extractor adaptation network $\psi_f$.
Figure 13: Overview of the auto-regressive feature extractor adaptation in AR-CNAPS: in addition to the structure shown in Figure 6, AR-CNAPS takes advantage of a series of pre-block set encoders to further condition the output of each adaptation module $\psi_f^j$ on a block-level set representation $z_j^{AR}$. This representation is formed by first adapting the previous blocks and then pushing the support set through the adapted blocks to form an auto-regressive adapted set representation at block $j$. This way, adaptation modules later in the pipeline are explicitly aware of the changes made by the previous adaptation networks and can adjust accordingly.

In [requeima2019fast], an additional auto-regressive variant for adapting the feature extractor is proposed, referred to as AR-CNAPS. As shown in Figure 13, AR-CNAPS extends CNAPS by introducing a block-level set encoder at each block $j$. These set encoders use the output obtained by pushing the support set through all previous adapted blocks to form the block-level set representation $z_j^{AR}$. This representation is then used as input to the adaptation module $\psi_f^j$ in addition to the task representation $z_\tau$. In this way the adaptation network is not only conditioned on the task, but is also aware of potential changes in the previous blocks resulting from the adaptation performed by the adaptation modules before it. The auto-regressive nature of AR-CNAPS allows for a more dynamic adaptation procedure that boosts performance in certain domains.
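The auto-regressive loop can be sketched as below. The callable interfaces (a block invoked as `block(h, gamma, beta)`, one set encoder and one adapter per block, and the concatenation of the task and block-level representations) are assumptions made for illustration rather than the exact AR-CNAPS implementation.

```python
import torch

def ar_adapt(resnet_blocks, block_set_encoders, adapters, support_images, z_task):
    """Auto-regressive feature extractor adaptation (sketch of AR-CNAPS).

    resnet_blocks:      FiLM-augmented ResNet blocks, called as block(h, gamma, beta)
    block_set_encoders: one set encoder per block, producing block-level representations
    adapters:           one adaptation module psi_f^j per block
    z_task:             task representation from the task encoder g
    """
    film_params = []
    h = support_images
    for block, encoder, adapter in zip(resnet_blocks, block_set_encoders, adapters):
        z_ar = encoder(h).mean(dim=0)                    # block-level set representation
        gamma, beta = adapter(torch.cat([z_task, z_ar], dim=-1)).chunk(2, dim=-1)
        film_params.append((gamma, beta))
        h = block(h, gamma, beta)                        # push the support set through the adapted block
    return film_params
```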

A.2 FiLM Layers

Proposed by [perez2018film], Feature-wise Linear Modulation (FiLM) layers were originally used for visual question answering, where the feature extractor could be conditioned on the question. As shown in Figure 14, these layers are inserted within residual blocks, where the feature channels are scaled and shifted using the respective FiLM parameters $\gamma$ and $\beta$. This can be extremely powerful in transforming the extracted feature space. In our work and in [requeima2019fast], these FiLM parameters are conditioned on the support images of the task $\tau$. This way, the adapted feature extractor is able to modify the feature space to extract the features that allow the classes in the task to be distinguished most distinctly. This is particularly powerful when the classification metric is changed to the Mahalanobis distance: with the new objective, the feature extractor adaptation network learns to extract better-suited features (see the difference between with and without $\psi_f$ in Table 6 for CNAPS and Simple CNAPS).
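The modulation itself is a one-line operation per layer; a minimal PyTorch sketch (channel-first tensors assumed):

```python
import torch.nn as nn

class FiLM(nn.Module):
    """Feature-wise Linear Modulation: scale and shift each channel of a feature map."""

    def forward(self, features, gamma, beta):
        # features: [B, C, H, W]; gamma, beta: [C], produced by the adaptation network psi_f
        return gamma.view(1, -1, 1, 1) * features + beta.view(1, -1, 1, 1)
```

Because $\gamma$ and $\beta$ are produced from the support set, the same pre-trained ResNet18 behaves differently for every task without updating its convolutional weights.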

Figure 14: Overview of FiLM layers: the figure is from [requeima2019fast]. Left) A FiLM layer operating on a series of channels indexed by $i$, scaling and shifting the feature channels as defined by the respective FiLM parameters $\gamma_i$ and $\beta_i$. Right) Placement of these FiLM modules within a ResNet18 [DBLP:journals/corr/HeZRS15-resnet] basic block.

A.3 Network Architectures

Figure 15: Overview of the neural architectures used in (Simple) CNAPS: a) Auto-regressive set encoder; note that since it is conditioned on the channel outputs of the convolutional filters, its input is not convolved any further. b) Task encoder $g$, which mean-pools convolutionally filtered support examples to produce the task representation. c) Architectural overview of the classifier adaptation network $\psi_c$, consisting of a three-layer MLP with a residual connection. The three diagrams are based on Tables E.8, E.9, and E.11 in [requeima2019fast].
Classification Accuracy (%)
Model ILSVRC Omniglot Aircraft CUB DTD QuickDraw Fungi Flower
CNAPS 49.6±1.1 87.2±0.8 81.0±0.7 69.7±0.9 61.3±0.7 72.0±0.8 *32.2±1.0 *70.9±0.8
Simple CNAPS 55.6±1.1 90.9±0.8 82.2±0.7 75.4±0.9 74.3±0.7 75.5±0.8 *39.9±1.0 *88.0±0.8
CNAPS 50.3±1.1 86.5±0.8 77.1±0.7 71.6±0.9 *64.3±0.7 *33.5±0.9 46.4±1.1 84.0±0.6
Simple CNAPS 58.1±1.1 90.8±0.8 83.8±0.7 75.2±0.9 *74.6±0.7 *64.0±0.9 47.7±1.1 89.9±0.6
CNAPS 51.5±1.1 87.8±0.8 *38.2±0.8 *58.7±1.0 62.4±0.7 72.5±0.8 46.9±1.1 89.4±0.5
Simple CNAPS 56.0±1.1 91.1±0.8 *66.6±0.8 *68.0±1.0 71.3±0.7 76.1±0.8 45.6±1.1 90.7±0.5
CNAPS *42.4±0.9 *59.6±1.4 77.2±0.8 69.3±0.9 62.9±0.7 69.1±0.8 40.9±1.0 88.2±0.5
Simple CNAPS *49.1±0.9 *76.0±1.4 83.0±0.8 74.5±0.9 74.4±0.7 74.8±0.8 44.0±1.0 91.0±0.5
Table 5: Cross-validated classification accuracy results. A * denotes that the dataset was excluded from training for that fold and therefore signifies out-of-domain performance. Simple CNAPS values in bold indicate statistically significant gains over CNAPS.
Average Accuracy with $\psi_f$ (%) Average Accuracy without $\psi_f$ (%)
Metric/Model Variation In-Domain Out-Domain Overall In-Domain Out-Domain Overall
Negative Dot Product 66.9±0.9 53.9±0.8 61.9±0.9 38.4±1.0 44.7±1.0 40.8±1.0
CNAPS 69.6±0.8 59.8±0.8 65.9±0.8 54.4±1.0 55.7±0.9 54.9±0.9
Absolute Distance (L1) 71.0±0.8 65.4±0.8 68.8±0.8 54.9±1.0 62.2±0.8 57.7±0.9
Squared Euclidean (L2) 71.7±0.8 66.3±0.8 69.6±0.8 55.3±1.0 61.8±0.8 57.8±0.9
Simple CNAPS -TR 73.5±0.8 69.6±0.8 72.0±0.8 52.3±1.0 61.7±0.9 55.9±1.0
Simple CNAPS 73.8±0.8 69.7±0.8 72.2±0.8 56.0±1.0 64.8±0.8 59.3±0.9
Table 6: Comparison of the in-domain, out-of-domain, and overall accuracy averages of each metric/model variation when feature extractor adaptation is performed (denoted “with $\psi_f$”) vs. when no adaptation is performed (denoted “without $\psi_f$”). Values in bold signify the best performance in each column, while underlined values signify superior performance of Simple CNAPS (and the -TR variation) compared to the CNAPS baseline.
Average Classification Accuracy (%)
Fold Model In-Domain Out-Domain Overall
1 CNAPS 70.1±0.4 51.6±0.4 65.5±0.4
1 S. CNAPS 75.7±0.3 64.0±0.4 72.7±0.3
2 CNAPS 69.3±0.4 48.9±0.3 64.2±0.4
2 S. CNAPS 74.3±0.4 69.3±0.4 73.0±0.3
3 CNAPS 68.4±0.4 48.5±0.4 63.4±0.4
3 S. CNAPS 71.8±0.4 67.3±0.5 70.7±0.4
4 CNAPS 67.9±0.3 51.0±0.7 63.7±0.4
4 S. CNAPS 73.6±0.3 62.6±0.6 70.9±0.4
Avg CNAPS 69.0±1.4 50.0±1.8 64.2±1.6
Avg S. CNAPS 73.8±1.3 65.8±1.8 71.8±1.4
Table 7: Cross-validated in-domain, out-of-domain and overall classification accuracies averaged across each fold and combined. Note that for conciseness, Simple CNAPS has been shortened to “S. CNAPS”. Simple CNAPS values in bold indicate statistically significant gains over CNAPS.

We adopt the same architectural choices for the task encoder $g$, the auto-regressive set encoders, and the feature extractor adaptation network $\psi_f$ as [requeima2019fast]. The neural architecture for each adaptation module $\psi_f^j$ inside $\psi_f$ is shown in Figure 12. The neural configurations for the task encoder and the auto-regressive set encoders used in AR-CNAPS are shown in Figure 15-a and Figure 15-b respectively. Note that the auto-regressive set encoders do not need convolutional layers: their input comes from the output of the corresponding adapted residual block, which has already been processed by convolutional filters.

Unlike CNAPS, we do not use the classifier adaptation network $\psi_c$. As shown in Figure 15-c, $\psi_c$ consists of an MLP with three fully connected (FC) layers and the intermediate non-linearity ELU, a continuous approximation to ReLU defined below:

$$\mathrm{ELU}(x) = \begin{cases} x & x > 0 \\ e^{x} - 1 & x \le 0 \end{cases} \qquad (6)$$

As mentioned previously, without the need to learn the three FC layers in $\psi_c$, Simple CNAPS has 788,485 fewer parameters while outperforming CNAPS by considerable margins.
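For reference, a sketch of what such a classifier adaptation network looks like is given below; the layer widths and the exact placement of the residual connection are illustrative assumptions, not the published configuration.

```python
import torch
import torch.nn as nn

class ClassifierAdapterSketch(nn.Module):
    """Illustrative psi_c: a three-FC-layer ELU MLP with a residual connection, mapping a
    class mean of dimension d to d classifier weights plus one bias."""

    def __init__(self, d=512):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(d, d), nn.ELU(),
            nn.Linear(d, d), nn.ELU(),
            nn.Linear(d, d + 1),
        )

    def forward(self, class_mean):
        out = self.mlp(class_mean)
        weights = out[:-1] + class_mean     # residual connection on the weight head (assumed)
        bias = out[-1:]
        return torch.cat([weights, bias])
```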

Appendix B Cross Validation

The Meta-Dataset [triantafillou2019meta], with its split into 8 in-domain and 2 out-of-domain datasets, defines the benchmark setting for the baseline results provided. The splits between the datasets were intended to capture an extensive set of visual domains for evaluating the models.

However, while all past work relies directly on the provided setup, we go further by verifying that our model does not overfit to the proposed splits and consistently outperforms the baseline under different permutations of the datasets. We examine this through a 4-fold cross-validation of Simple CNAPS and CNAPS on the following 8 datasets: ILSVRC-2012 (ImageNet) [russakovsky2015imagenet], Omniglot [lake2015human], FGVC-Aircraft [maji2013fine], CUB-200-2011 (Birds) [wah2011caltech], Describable Textures (DTD) [cimpoi2014describing], QuickDraw [jongejan2016quick], FGVCx Fungi [fungi2018schroeder] and VGG Flower [nilsback2008automated]. During each fold, two of the datasets are excluded from training, and both Simple CNAPS and CNAPS are trained and evaluated in that setting.

As shown by the classification results in Table 5, in all four folds of validation Simple CNAPS outperforms CNAPS on 7-8 out of the 8 datasets. The in-domain, out-of-domain, and overall averages for each fold, reported in Table 7, also show Simple CNAPS's accuracy gains over CNAPS by substantial margins. In fact, the smaller number of in-domain datasets in the cross-validation (6 vs. 8) actually leads to wider gaps between Simple CNAPS and CNAPS, suggesting that Simple CNAPS is a more powerful alternative in the low-domain setting. Furthermore, these results illustrate that our gains are not specific to the Meta-Dataset setup.

Appendix C Ablation study of the Feature Extractor Adaptation Network

In addition to the choice-of-metric ablation study referenced in Section 6.1, we examine the behaviour of the model when the feature extractor adaptation network $\psi_f$ is turned off. In this setting, the feature extractor consists only of the pre-trained ResNet18 [DBLP:journals/corr/HeZRS15-resnet]. Consistent with [requeima2019fast], we refer to this setting as “No Adaptation” (or “No Adapt” for short). We compare the “No Adapt” variation to the feature-extractor-adaptive case for each of the metrics/model variations examined in Section 6.1. The in-domain, out-of-domain, and overall classification accuracies are shown in Table 6. As shown, without $\psi_f$ all models lose approximately 15, 5, and 12 percentage points of in-domain, out-of-domain, and overall accuracy respectively, while Simple CNAPS continues to hold the lead, especially in out-of-domain classification accuracy. It is interesting to note that without the task-specific regularization term (denoted “-TR”), there is a considerable performance drop in the “No Adaptation” setting, whereas when the feature extractor adaptation network is present, the difference is marginal. This signifies two important observations. First, it shows the importance of learning the feature extractor adaptation module end-to-end with the Mahalanobis distance, as it is able to adapt the feature space to best suit the squared Mahalanobis distance. Second, the adaptation function can reduce the importance of the task regularizer by properly de-correlating and normalizing the variance within the feature vectors. However, when this is not possible, as in the “No Adaptation” case, the all-classes-in-task covariance estimate acting as an added regularizer in Equation 3 becomes crucial to maintaining superior performance.

Appendix D Projection Networks

We additionally explored metric learning where, in addition to changing the distance metric, we considered projecting each support feature vector and the query feature vector into a new decision space in which the squared Mahalanobis distance is then used for classification. Specifically, we trained a projection network $f_P$ such that, for Equations 2 and 3, $\mu_k$, $\Sigma_k^\tau$, and $\Sigma^\tau$ were calculated from the projected feature vectors $f_P(f_\theta^\tau(x_i))$ as opposed to the feature vectors $f_\theta^\tau(x_i)$ themselves. Similarly, the projected query feature vector $f_P(f_\theta^\tau(x^*_i))$ was used for classifying the query example, as opposed to the bare feature vector used within Simple CNAPS. We define $f_P$ in our experiments to be the following:

(7)

where ELU, a continuous approximation to ReLU as previously noted, is used as the choice of non-linearity, and the weight matrices and bias vectors are learned parameters.
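Since the exact form of Equation 7 is not reproduced here, the sketch below only illustrates the idea of the “+P” variant under the assumption that $f_P$ is a small ELU MLP applied to the adapted features before the class means, covariances, and Mahalanobis distances are computed.

```python
import torch.nn as nn

class ProjectionNetworkSketch(nn.Module):
    """Hypothetical f_P: project adapted features into a new decision space before the
    parameter-free Mahalanobis classification head is applied ('Simple CNAPS +P')."""

    def __init__(self, d_in=512, d_out=512):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(d_in, d_out), nn.ELU(), nn.Linear(d_out, d_out))

    def forward(self, feats):
        return self.net(feats)

# Usage sketch, reusing the simple_cnaps_classify sketch from Section 4.2:
#   f_p = ProjectionNetworkSketch()
#   probs = simple_cnaps_classify(f_p(query_feats), f_p(support_feats), support_labels)
```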

Average Classification Accuracy (%)
Model In-Domain Out-Domain Overall
Simple CNAPS +P 72.4±0.9 67.1±0.8 70.4±0.8
Simple CNAPS 73.8±0.8 69.7±0.8 72.2±0.8
Table 8: Comparing the in-domain, out-of-domain and overall classification accuracy of Simple CNAPS +P (with projection networks) to Simple CNAPS. Values in bold show the statistically significant best result.

We refer to this variation of our model as “Simple CNAPS +P”, with the “+P” tag signifying the addition of the projection function $f_P$. The results for this variation are compared to the base Simple CNAPS in Table 8. As shown, the projection network generally results in lower performance, although not to a statistically significant degree for in-domain and overall accuracy. Where the addition of the projection network results in a substantial loss of performance is in the out-of-domain setting, with Simple CNAPS +P achieving an average accuracy of 67.1±0.8 compared to 69.7±0.8 for Simple CNAPS. We hypothesize that this significant loss in out-of-domain performance is due to the projection network overfitting to the in-domain datasets.