Universal Representation Learning from Multiple Domains for Few-shot Classification

Wei-Hong Li, et al. · 03/25/2021

In this paper, we look at the problem of few-shot classification that aims to learn a classifier for previously unseen classes and domains from few labeled samples. Recent methods use adaptation networks for aligning their features to new domains or select the relevant features from multiple domain-specific feature extractors. In this work, we propose to learn a single set of universal deep representations by distilling knowledge of multiple separately trained networks after co-aligning their features with the help of adapters and centered kernel alignment. We show that the universal representations can be further refined for previously unseen domains by an efficient adaptation step in a similar spirit to distance learning methods. We rigorously evaluate our model in the recent Meta-Dataset benchmark and demonstrate that it significantly outperforms the previous methods while being more efficient. Our code will be available at https://github.com/VICO-UoE/URL.


1 Introduction

Figure 1: Universal Representation Learning (URL). To learn universal representations from multiple domains that can generalize to previously unseen domains, one strategy [13, 29] is to learn one feature extractor for each domain and learn to retrieve or combine feature extractors for the target task during the meta-test stage, as in (a). We propose a universal representation network (b), which is learned by distilling the knowledge learned from multiple datasets into one single feature extractor shared across all domains. In the meta-test stage, we use a linear transformation that further refines the universal representations for better generalization to unseen domains. Our universal representation network achieves better generalization performance than using multiple domain-specific ones while being more efficient than (a).

As deep neural networks progress to dramatically improve results in most standard computer vision tasks, there is growing interest in the community for more ambitious goals. One of them is to improve the data efficiency of standard supervised methods that rely on large amounts of expensive and time-consuming hand-labeled data. Just as human intelligence is capable of learning concepts from few labeled samples, few-shot learning [24, 33] aims at adapting a classifier to accommodate new classes not seen in training, given only a few labeled samples from these classes.

Earlier works in few-shot learning focus on evaluating their methods on homogeneous learning tasks, e.g. Omniglot [25], miniImageNet [53], tieredImageNet [43], where both the meta-train and meta-test examples are sampled from a single data distribution (or dataset). Recently, the interest of the community has shifted to a more realistic and challenging experimental setting, where the goal is to learn few-shot models that generalize not only within a single data distribution but also to previously unseen data distributions. To this end, Triantafillou et al. [52] propose a new heterogeneous benchmark, Meta-Dataset, which consists of ten datasets from different domains for meta-training and meta-testing. While initially two domains were kept as unseen, three more unseen domains were later included to meta-test the generalization ability of learned models.

While the few-shot methods [14, 48, 49, 53] that were proposed before Meta-Dataset was available can be directly applied to this new benchmark with minor modifications, they fail to cope with the domain gap between train and test datasets and thus obtain subpar performance on Meta-Dataset. Recently, several few-shot learning methods have been proposed to address this challenge, which can be coarsely grouped into two categories: adaptation based [2, 44] and feature selection based [13, 29] methods. CNAPS [44] consists of an adaptation network that modulates the parameters of both a feature extractor and classifier for new categories by encoding the data distribution of few training samples. Simple CNAPS [2] extends CNAPS by replacing its parametric classifier with a non-parametric classifier based on the Mahalanobis distance and shows that adapting the classifier from few samples is not necessary for good performance. SUR [13] and URT [29] further show that adaptation of the feature extractor can also be replaced by a feature selection mechanism. In particular, both methods [13, 29] learn a separate deep network for each training dataset in an offline stage, employ them to extract multiple features for each image, and then select the optimal set of features either based on a similarity measure [13] or by learning an attention mechanism [29]. However, despite their good performance, SUR and URT are computationally expensive and require multiple forward passes through multiple networks at inference time.

In this work, we propose an efficient and high-performance few-shot method based on multi-domain learning. Like [13, 29], our method builds on multi-domain representations that are learned in an offline stage. However, unlike them, we learn a single set of universal representations (a single deep neural network) over multiple domains, which has a fixed computational cost at inference regardless of the number of domains. Similar to the adaptation based techniques [2, 44], our method further employs a simple adaptation strategy to learn domain-specific representations from few samples (see Fig. 1 for an illustration).

In particular, we propose to distill the knowledge of multiple domains into a single model, which can efficiently leverage useful information from multiple diverse domains. Learning multi-domain representations is a challenging task and requires leveraging commonalities between the domains while minimizing interference (negative transfer) between them (e.g. [8, 41, 56]). To mitigate this, we align the intermediate representations of our multi-domain network with those of the domain-specific networks after carefully aligning each space by using small task-specific adapters and Centered Kernel Alignment (CKA) [22]. Finally, inspired by the use of the Mahalanobis distance in [2], we adapt the learned multi-domain features to the new task by mapping them into a task-specific space. However, unlike [2], we learn the parameters of this mapping via adaptation in a discriminative way. We rigorously evaluate our method on the Meta-Dataset benchmark and show that it significantly outperforms the state-of-the-art few-shot methods in both seen and unseen domain generalization.

2 Related Work

Meta-learning based few-shot classification.

One approach that directly trains a model to perform few-shot classification end-to-end is meta-learning. Meta-learning approaches for few-shot learning can be broadly divided into two groups: metric-based and optimization-based approaches. The key idea in the former group is to map raw images to vector representations and use nearest neighbor classifiers with different distance functions, by learning discriminative feature spaces with Siamese networks [21], producing a weighted nearest neighbor classifier [53], or representing each class with the average of the samples in the support set [48]. The latter group focuses on learning models that can quickly adapt to new tasks from few support samples. Successful methods include MAML [14], which poses the learning-to-learn problem as a bi-level optimization where the weights of the network are modeled as a function of the initial network weights; Reptile [35], which alleviates the expensive second-order derivative computation in MAML by a first-order approximation; and MAML++ [1], which introduces multiple speed and stability improvements over MAML.

Transfer learning based few-shot classification.

There are also simple yet effective methods [6, 7, 11] that first learn a neural network on all the available training data and transfer it to few-shot tasks at test time. Baseline++ [6] only updates a parametric classifier with cosine distance, while Meta-Baseline [7] fine-tunes the entire network with a nearest-centroid cosine similarity and a scale parameter. Dhillon et al. [11] explore fine-tuning in a transductive setting, where the query set is assumed to be available at the same time.

Cross-domain few-shot classification.

Recent few-shot techniques [5, 13, 29, 44] focus on few-shot learning that generalizes to unseen domains at test time on the recently proposed Meta-Dataset [52]. CNAPS [44] adapts the parameters of the feature encoder and classifier by conditioning them on the current task via FiLM layers [39]; it is further extended by Simple CNAPS [2], which adopts a non-parametric classifier using a simple class-covariance-based distance metric, namely the Mahalanobis distance. In contrast, SUR [13] stores domain-specific knowledge by learning an independent feature extractor for each domain, and automatically selects the most relevant representations for a new task by linearly combining the domain-specific features. URT [29] instead meta-learns the feature selection mechanism for new tasks by using Transformer layers. Like SUR and URT, our method uses multi-domain features but in a more efficient way, by learning a single network over multiple domains. Our method requires significantly less network capacity and compute than theirs. In addition, similar to Simple CNAPS [2], we map our features to a task-specific space before applying the nearest neighbor classifier, but we learn the parameters of this mapping from each support set.

Knowledge distillation.

Our work is related to knowledge distillation (KD) methods [17, 27, 30, 40, 45, 50] that distill the knowledge of an ensemble of large teacher models to a small student neural network at the classifier [17] and intermediate layers [45]. Born-Again Neural Networks [15] propose to consecutively distill knowledge from an identical teacher network to a student network via KD, which is further applied to few-shot learning in [51] and multi-task learning in [10]. Most similar to our work, Li and Bilen [27] apply knowledge distillation to align the features of a student multi-task network with multiple single-task networks by introducing task-specific adapters. While we use task-specific adapters to align the features across multiple networks as in [27], we apply the alignment to the more challenging setting of multi-domain learning, where there is a substantial gap between domains, unlike their method, which is shown to work in multi-task learning where multiple tasks are sampled from a single data distribution. To this end, we incorporate a more effective feature matching loss inspired by Centered Kernel Alignment (CKA) to align features in the presence of a large domain gap.

Universal representation.

A representation that works equally well in multiple domains, termed a universal representation, is introduced in [3]. To learn universal representations over multiple domains, SUR [13] and URT [29] propose to learn an independent model for each domain and to retrieve or blend the appropriate models for a new task in few-shot classification. Alternatively, [3, 41, 42] propose to learn a single network that performs image classification on very different domains by sharing a large majority of parameters across domains and encoding domain-specific information via normalization layers [3], light-weight residual adapters [41, 42], or Feature-wise Linear Modulation (FiLM) [39]. Inspired by these methods, we learn universal representations without any domain-specific weights and use them in few-shot learning.

3 Method

In this section, we describe the problem setting and introduce our method in two parts: multi-domain feature learning and feature adaptation.

3.1 Few-shot Task Formulation

Few-shot classification aims at learning to classify samples from a small training set with only a few samples for each class. A task contains two sets of images: a support set $\mathcal{S} = \{(x_i, y_i)\}$ of image and label pairs that defines the classification task, and a query set $\mathcal{Q}$ of samples to be classified. In words, we would like to learn a classifier on the support set that can accurately predict the labels of the query set.

As in [13, 29], we solve this problem in two steps: i) a meta-training step, where a learning algorithm receives a large dataset $\mathcal{D}^{\text{train}}$ and outputs a general feature extractor $f$; ii) a meta-test step, where the target tasks are sampled from another large dataset $\mathcal{D}^{\text{test}}$ by taking subsets of the dataset to build $\mathcal{S}$ and $\mathcal{Q}$. Note that $\mathcal{D}^{\text{train}}$ and $\mathcal{D}^{\text{test}}$ contain mutually exclusive classes.
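As a concrete illustration, a few-shot task could be sampled as follows. This is a minimal sketch: the `images_by_class` mapping, the helper name and the way/shot counts are hypothetical, and meta-train and meta-test tasks would draw from disjoint class pools as noted above.

```python
# Minimal sketch of sampling an N-way K-shot episode from a dict mapping
# class id -> list of image tensors; all names here are illustrative.
import random

def sample_episode(images_by_class, n_way=5, k_shot=1, n_query=10):
    classes = random.sample(list(images_by_class), n_way)
    support, query = [], []
    for label, c in enumerate(classes):
        samples = random.sample(images_by_class[c], k_shot + n_query)
        support += [(x, label) for x in samples[:k_shot]]
        query += [(x, label) for x in samples[k_shot:]]
    return support, query  # fit the classifier on support, evaluate on query
```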

3.2 Learning multiple domain representations

Our focus is few-shot image classification that generalizes not only within previously seen visual domains but also to unseen ones. As it is challenging to obtain domain-specific knowledge from only a few samples of a previously unseen domain, inspired by [3, 41] we hypothesize that using domain-agnostic or universal representations is the key to successful cross-domain generalization. To this end, we propose learning a multi-domain network that works well for all the domain-specific tasks simultaneously and using this network as the feature extractor for the target tasks.

Let us assume that $\mathcal{D}^{\text{train}}$ consists of $K$ sub-datasets $\{\mathcal{D}_k\}_{k=1}^{K}$, each sampled from a different domain. One potential solution is to train a multi-domain network by jointly optimizing its parameters over the images from all domains (datasets):

$$\min_{\phi, \{\psi_k\}_{k=1}^{K}} \sum_{k=1}^{K} \mathbb{E}_{(x,y) \sim \mathcal{D}_k} \big[ \ell\big(g_{\psi_k} \circ f_{\phi}(x), y\big) \big] \qquad (1)$$

where $\ell$ is the cross-entropy loss, $f_{\phi}$ is a multi-domain feature extractor that takes an image as input, outputs a $d$-dimensional feature, and is parameterized by a single set of parameters $\phi$ shared across domains, and $g_{\psi_k}$ is a domain-specific classifier that takes in $f_{\phi}(x)$, outputs a probability vector over the target categories, and is parameterized by $\psi_k$. While minimizing Eq. 1 results in a multi-domain feature extractor $f_{\phi}$, several previous works report that this optimization is problematic due to interference between the different tasks [8, 56] and varying dataset sizes and difficulty [20, 27], and often leads to subpar results compared to individual single-domain networks.
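For concreteness, a minimal PyTorch sketch of this vanilla joint objective is given below; the module and helper names (`MultiDomainNet`, `mdl_step`) are illustrative assumptions, not the released implementation.

```python
# Sketch of the vanilla multi-domain objective in Eq. 1: one shared feature
# extractor f_phi and one classifier head g_psi_k per domain, trained jointly
# with a summed cross-entropy loss.
import torch
import torch.nn as nn

class MultiDomainNet(nn.Module):
    def __init__(self, backbone, feat_dim, classes_per_domain):
        super().__init__()
        self.backbone = backbone  # f_phi, shared across all K domains
        self.heads = nn.ModuleList(
            [nn.Linear(feat_dim, c) for c in classes_per_domain])  # g_psi_k

    def forward(self, x, k):
        return self.heads[k](self.backbone(x))

def mdl_step(model, batches, optimizer):
    # batches: one (images, labels) pair per domain k
    ce = nn.CrossEntropyLoss()
    loss = sum(ce(model(x, k), y) for k, (x, y) in enumerate(batches))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```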

Figure 2: Illustration of our proposed method for multi-domain feature learning. Given training images from $K$ different domains, we first train $K$ domain-specific networks $f_{\phi_k}$ and their classifiers $g_{\psi_k}$, freeze their weights, and distill their knowledge to our multi-domain network $f_{\phi}$ by matching their features and predictions through two loss functions $\ell^{f}$ and $\ell^{kl}$ respectively. As matching multiple features is challenging, we co-align all the features by using light-weight adaptors and centered kernel alignment.

Motivated by this challenge, we propose a two-stage procedure to learn multi-domain representations, inspired by previous distillation methods [17, 27]. To this end, we first train $K$ domain-specific deep networks, where each consists of a feature extractor and a classifier with parameters $\phi_k$ and $\psi_k$ respectively, similarly to [13, 29]. However, instead of keeping the domain-specific feature extractors and selecting the most relevant features as they do, we propose to learn a single multi-domain network that performs well in all $K$ domains by distilling the knowledge of the $K$ pretrained feature extractors. This has two key advantages over [13, 29]. First, using a single feature extractor, which has the same capacity as each domain-specific one, is significantly more efficient in terms of run-time and number of parameters in the meta-test stage. Second, learning to find the most relevant features for a given support and query set as in [29] is not trivial and may also suffer from overfitting to the small number of datasets in the training set, while the multi-domain representations automatically contain the required information from the relevant domains.

In the second stage, we freeze the pretrained domain-specific feature extractors and transfer their knowledge into the multi-domain model at train time. Knowledge distillation can be performed at the prediction [17] and feature level [27, 45] by minimizing the distance between (i) the predictions of the multi-domain and corresponding single-domain network, and (ii) the multi-domain and single-domain features for given training samples. While Kullback-Leibler (KL) divergence is the standard choice for the predictions as in [17], matching the multi-domain features to multiple single-domain ones simultaneously is an ill-posed problem, as the domain-specific features for a given image can vary significantly across different domains. To this end, as in [27], we propose to map each domain-specific feature into a common space by using adaptors $h_{\theta_k}$ with parameters $\theta_k$ and to jointly train them along with the parameters of the multi-domain network:

$$\min_{\phi, \{\psi_k, \theta_k\}_{k=1}^{K}} \sum_{k=1}^{K} \mathbb{E}_{(x,y) \sim \mathcal{D}_k} \Big[ \ell\big(g_{\psi_k} \circ f_{\phi}(x), y\big) + \lambda^{kl}_{k}\, \ell^{kl}\big(g_{\psi_k} \circ f_{\phi}(x),\, g_{\psi^{*}_k} \circ f_{\phi^{*}_k}(x)\big) + \lambda^{f}_{k}\, \ell^{f}\big(f_{\phi}(x),\, h_{\theta_k} \circ f_{\phi^{*}_k}(x)\big) \Big] \qquad (2)$$

where $\ell^{kl}$ is the KL divergence on network predictions, $\ell^{f}$ is a distance function in the feature space, $\lambda^{kl}_{k}$ and $\lambda^{f}_{k}$ are their domain-specific weights, and $\phi^{*}_k$ and $\psi^{*}_k$ denote the frozen parameters of the pretrained domain-specific networks. We illustrate this key idea in Fig. 2. In words, the multi-domain network is optimized to match the domain-specific features up to a transformation (i.e. $h_{\theta_k}$) and to predict the ground-truth classes $y$.
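A sketch of the resulting per-domain loss is shown below, assuming frozen pretrained teachers, a trainable adaptor `adaptor_k` that maps teacher features into the common space, and a `cka_loss` helper defined in the next sketch; the softmax temperature `T` is an assumption of this sketch rather than a detail stated above.

```python
# Sketch of the domain-k term of Eq. 2: cross-entropy on ground-truth labels,
# KL distillation on predictions, and a feature-matching loss on adapted
# teacher features.
import torch
import torch.nn.functional as F

def distill_loss_domain_k(f_phi, g_psi_k, f_phi_k, g_psi_k_frozen, adaptor_k,
                          x, y, lam_kl=1.0, lam_f=1.0, T=4.0):
    feat = f_phi(x)                         # multi-domain (student) features
    logits = g_psi_k(feat)                  # student predictions for domain k
    with torch.no_grad():
        feat_k = f_phi_k(x)                 # frozen domain-specific features
        logits_k = g_psi_k_frozen(feat_k)   # frozen teacher predictions
    ce = F.cross_entropy(logits, y)         # predict the ground-truth classes
    kl = F.kl_div(F.log_softmax(logits / T, dim=1),
                  F.softmax(logits_k / T, dim=1),
                  reduction='batchmean') * T * T
    feat_loss = cka_loss(feat, adaptor_k(feat_k))  # ell^f of Eq. 3, see below
    return ce + lam_kl * kl + lam_f * feat_loss
```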

While Li and Bilen [27] show that the L2 distance is effective for matching features across task-agnostic and task-specific networks, which are trained for different tasks on a single domain, here we argue that matching features trained on substantially diverse domains requires a more complex distance function that can model non-linear correlations between the representations. To this end, inspired by [22], we adopt the Centered Kernel Alignment (CKA) [22] similarity index with the RBF kernel, which is shown to be capable of measuring meaningful non-linear similarities between representations of higher dimension than the number of data points.

Next we briefly describe CKA. Suppose $F \in \mathbb{R}^{n \times d}$ and $F^{*} \in \mathbb{R}^{n \times d}$ denote the features computed by the multi-domain and domain-specific networks respectively for a given set of $n$ images. We first compute the Radial Basis Function (RBF) kernel matrices $K$ and $K^{*}$ of $F$ and $F^{*}$ respectively. Then we use the two kernel matrices $K$ and $K^{*}$ to measure the dissimilarity of $F$ and $F^{*}$ as follows:

$$\ell^{f}(F, F^{*}) = 1 - \frac{\operatorname{tr}(K H K^{*} H)}{\sqrt{\operatorname{tr}(K H K H)\, \operatorname{tr}(K^{*} H K^{*} H)}} \qquad (3)$$

where $\operatorname{tr}$ and $H = I_n - \frac{1}{n}\mathbf{1}\mathbf{1}^{\top}$ denote the trace of a matrix and the centering matrix respectively, and the second term is the CKA similarity between the multi-domain and domain-specific features. As the original CKA similarity requires the computation of the kernel matrices over whole datasets, which is not scalable to large datasets, we follow [34] and compute them over each minibatch during training. We refer to [22, 34] for more details.
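A minibatch sketch of this loss is given below; the RBF bandwidth heuristic (a fraction of the median pairwise squared distance, in the spirit of [22]) is an assumption of this sketch.

```python
# Sketch of the minibatch CKA dissimilarity of Eq. 3 with RBF kernels.
import torch

def rbf_kernel(x, sigma_frac=0.5):
    d2 = torch.cdist(x, x).pow(2)           # pairwise squared distances
    sigma2 = sigma_frac * d2.median()       # bandwidth heuristic (assumption)
    return torch.exp(-d2 / (2 * sigma2 + 1e-12))

def cka_loss(feat_a, feat_b):
    n = feat_a.shape[0]
    H = torch.eye(n, device=feat_a.device) - 1.0 / n   # centering matrix H
    K, Ks = rbf_kernel(feat_a), rbf_kernel(feat_b)
    KH, KsH = K @ H, Ks @ H                 # K H and K* H
    cka = torch.trace(KH @ KsH) / (
        torch.sqrt(torch.trace(KH @ KH) * torch.trace(KsH @ KsH)) + 1e-12)
    return 1.0 - cka                        # Eq. 3: 1 - CKA similarity
```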

3.3 Feature adaptation in meta-test

During meta-test, given the support set of a new learning task, we use the multi-domain model to extract features and adapt them to the target task. To this end, we apply a linear transformation $A_{\vartheta}$ with learnable parameters $\vartheta$ to the computed features, i.e. $z = A_{\vartheta} \circ f_{\phi}(x)$ where $A_{\vartheta} \in \mathbb{R}^{d \times d}$. Then we follow a pipeline similar to the one in [13, 32, 48] and build a centroid classifier by averaging the embeddings belonging to each class:

$$c_n = \frac{1}{|\mathcal{S}_n|} \sum_{x_i \in \mathcal{S}_n} A_{\vartheta} \circ f_{\phi}(x_i) \qquad (4)$$

where $\mathcal{S}_n$ denotes the support samples with label $n$ and $N$ is the number of classes in the support set. Next we estimate the likelihood of a support sample $x_i$ belonging to class $n$ by:

$$p(y = n \mid x_i) = \frac{\exp\big(-d(A_{\vartheta} \circ f_{\phi}(x_i), c_n)\big)}{\sum_{n'=1}^{N} \exp\big(-d(A_{\vartheta} \circ f_{\phi}(x_i), c_{n'})\big)} \qquad (5)$$

where $d$ is the negative cosine similarity.

We then optimize $\vartheta$ to minimize the following objective on the support set $\mathcal{S}$:

$$\min_{\vartheta} \frac{1}{|\mathcal{S}|} \sum_{(x_i, y_i) \in \mathcal{S}} -\log p(y = y_i \mid x_i) \qquad (6)$$

Solving Eq. 6 for $\vartheta$ results in high intra-class and low inter-class similarity in the adapted space. We then use $A_{\vartheta}$ and Eq. 5 to predict the label of each query sample from $\mathcal{Q}$ by picking the closest centroid. Our meta-test pipeline is illustrated in Fig. 3.
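A compact sketch of this adaptation step is shown below. The identity initialization, Adadelta optimizer and 40-iteration budget follow our appendix; the rest is a re-implementation sketch that assumes support labels are 0..N-1 and omits any temperature scaling.

```python
# Sketch of meta-test adaptation (Eqs. 4-6): learn a linear map A on the
# support set with an NCC loss under cosine similarity, then classify queries.
import torch
import torch.nn.functional as F

def adapt_and_classify(support_feat, support_y, query_feat, n_iter=40, lr=0.1):
    d = support_feat.shape[1]
    A = torch.eye(d, device=support_feat.device, requires_grad=True)
    opt = torch.optim.Adadelta([A], lr=lr)

    def centroids(z):                       # Eq. 4: per-class mean embeddings
        return torch.stack([z[support_y == c].mean(0)
                            for c in support_y.unique()])

    for _ in range(n_iter):
        z = support_feat @ A
        sim = F.normalize(z, dim=1) @ F.normalize(centroids(z), dim=1).T
        loss = F.cross_entropy(sim, support_y)  # Eqs. 5-6 on the support set
        opt.zero_grad()
        loss.backward()
        opt.step()

    with torch.no_grad():                   # assign queries to closest centroid
        c = centroids(support_feat @ A)
        sim = F.normalize(query_feat @ A, dim=1) @ F.normalize(c, dim=1).T
    return sim.argmax(dim=1)
```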

Discussion.

In [2], Simple CNAPS uses the (squared) Mahalanobis distance $(f(x) - c_n)^{\top} \Sigma_n^{-1} (f(x) - c_n)$ between a class centroid and a query image, where $\Sigma_n$ is a covariance matrix specific to the task and class, and $c_n$ is the class centroid in the feature space (before adaptation). The authors show that considering the class covariance enables better adaptation of the feature extractor to the target task. Our adaptation strategy can be seen as a generalization of the Mahalanobis distance computation. Alternatively, assuming that $\Sigma_n^{-1}$ can be decomposed into a product of a lower triangular matrix and its conjugate transpose, i.e. $\Sigma_n^{-1} = L_n L_n^{\top}$, one can first pre-transform the features by multiplication, i.e. $L_n^{\top} f(x)$, and then compute the Euclidean distance between these features and the centroids. Similarly, we apply a linear transformation to the features, but unlike [2], we learn its parameters by optimizing Eq. 6.
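To make this equivalence explicit, with $L_n$ the Cholesky factor of $\Sigma_n^{-1}$ as above, the squared Mahalanobis distance is a Euclidean distance after a fixed linear map:

$$(f(x) - c_n)^{\top} \Sigma_n^{-1} (f(x) - c_n) = (f(x) - c_n)^{\top} L_n L_n^{\top} (f(x) - c_n) = \big\| L_n^{\top} \big(f(x) - c_n\big) \big\|_2^2$$

Our $A_{\vartheta}$ plays an analogous role to $L_n^{\top}$, but is learned discriminatively from the support set and shared across all classes of the task.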

Figure 3: Illustration of adaptation procedure in meta-test. Given a support set and query image, our method learns to map their features to a task-specific space through a linear transformation and assign the query image to the nearest class center.

4 Experiments

Test Dataset Proto-MAML [52] BOHB-E [47] CNAPS [44] Best SDL MDL Simple CNAPS [2] SUR [13] URT [29] Ours
ImageNet
Omniglot
Aircraft
Birds
Textures
Quick Draw
Fungi
VGG Flower
Traffic Sign
MSCOCO
MNIST - - -
CIFAR-10 - - -
CIFAR-100 - - -
Average Rank 7.8 8.1 6.6 4.8 4.6 5.2 5.0 4.4 1.3
Table 1: Comparison to baselines and state-of-the-art methods on Meta-Dataset. Mean accuracy and 95% confidence interval are reported. The first eight datasets are seen during training and the last five datasets are unseen and used for test only. Average rank is computed over the first 10 datasets as some methods do not report results on the last three datasets.

Here we first describe the benchmark, implementation details and competing methods. Then we rigorously compare our method to the state of the art and study each proposed component in an ablation. We also analyze our method qualitatively. Finally, we evaluate our method in a global retrieval task to further assess the learned feature representations beyond few-shot classification.

4.1 Experimental setup

Dataset.

Meta-Dataset [52] is a few-shot classification benchmark that initially consisted of ten datasets: ILSVRC_2012 [46] (ImageNet), Omniglot [25], FGVC-Aircraft [31] (Aircraft), CUB-200-2011 [54] (Birds), Describable Textures [9] (DTD), QuickDraw [19], FGVCx Fungi [4] (Fungi), VGG Flower [36] (Flower), Traffic Signs [18] and MSCOCO [28]; it was later expanded with MNIST [26], CIFAR-10 [23] and CIFAR-100 [23]. We follow the standard procedure and use the first eight datasets for meta-training, where each dataset is further divided into train, validation and test sets with disjoint classes. Evaluation within these datasets measures the generalization ability in seen domains. The remaining five datasets are reserved as unseen domains for meta-test, measuring cross-domain generalization ability.

Implementation details.

We use the PyTorch [38] library to implement our method. In all experiments we build our method on a ResNet-18 [16] backbone for both single-domain and multi-domain networks. In the multi-domain network, we share all layers but the last classifier across domains. For training single-domain models, we strictly follow the training protocol in [13]: we use an SGD optimizer with momentum and a cosine annealing learning rate scheduler with the same hyperparameters. For our multi-domain network, we use the same optimizer and scheduler as before and train it for 240,000 iterations. We set $\lambda^{kl}_{k}$ and $\lambda^{f}_{k}$ to 4 for ImageNet and 1 for the other datasets, and use early stopping based on cross-validation over the validation sets of the 8 training datasets. We refer to the supplementary for more details.

Baselines and compared methods.

First we compare our method to our own baselines: i) the best single-domain model (Best SDL), where we use each single-domain network as the feature extractor, test it for few-shot classification on each dataset, and pick the best-performing model (see supplementary for the complete results). This involves evaluating 8 single-domain networks on 13 datasets and serves as a very competitive baseline. ii) The vanilla multi-domain learning baseline (MDL), learned by optimizing Eq. 1 without the proposed distillation method. As additional baselines, we include the best performing method in [52], i.e. Proto-MAML [52], as well as the state-of-the-art methods BOHB-E [47], CNAPS [44], SUR [13], URT [29], and Simple CNAPS [2] (results of Proto-MAML [52], BOHB-E [47], and CNAPS [44] are obtained from Meta-Dataset). For evaluation, we follow the standard protocol in [52]: we randomly sample 600 tasks for each dataset, and report average accuracy and the 95% confidence interval in all experiments. We reproduce results by training and evaluating SUR [13], URT [29], and Simple CNAPS [2] using their code for fair comparison, as recommended by Meta-Dataset.

Varying-Way Five-Shot / Five-Way One-Shot
Test Dataset Simple CNAPS [2] SUR [13] URT [29] Ours Simple CNAPS [2] SUR [13] URT [29] Ours
ImageNet
Omniglot
Aircraft
Birds
Textures
Quick Draw
Fungi
VGG Flower
Traffic Sign
MSCOCO
MNIST
CIFAR-10
CIFAR-100
Average Rank 3.0 3.0 2.5 1.5 2.8 3.5 2.3 1.3
Table 2: Results of Varying-Way Five-Shot and Five-Way One-Shot settings. Mean accuracies are reported and the results with confidence interval are shown in the supplementary.

4.2 Results

As in Meta-Dataset [52], we sample each task with a varying number of ways and shots and report the results in Table 1. Our method outperforms the state-of-the-art methods on seven out of eight seen datasets and four out of five unseen datasets. We also compute the average rank as recommended in [52]: our method ranks 1.3 on average, while the state-of-the-art methods SUR and URT rank 5.0 and 4.4, respectively. More specifically, we obtain significantly better results than the second best approach on Aircraft (+2.8), Birds (+2.1), Texture (+4.2), and VGG Flower (+1.5) for seen domains, and on Traffic Sign (+6.1) and MSCOCO (+3.8) for unseen domains (the accuracy of all methods on Traffic Sign differs from the one in the original papers, as a bug has been fixed in the Meta-Dataset repository; see https://github.com/google-research/meta-dataset/issues/54 for more details). The results show that jointly learning a single set of representations provides better generalization ability than fusing the ones from multiple single-domain feature extractors as done in SUR and URT. Notably, our method requires fewer parameters and less computation at inference than SUR and URT, as it runs only one universal network to extract features, while SUR and URT need to pass the query set through multiple single-domain networks.

Test Dataset L2 COSINE CKA KL CKA + KL
ImageNet
Omniglot
Aircraft
Birds
Textures
Quick Draw
Fungi
VGG Flower
Traffic Sign
MSCOCO
MNIST
CIFAR-10
CIFAR-100
Table 3: Comparison of loss functions for knowledge distillation. Mean accuracy and 95% confidence interval are reported. L2 denotes the L2 loss between two feature representations, COSINE the negative cosine similarity function, and KL the KL divergence loss on the network predictions. All results are obtained with feature adaptation during the meta-test stage.

We also see that our method outperforms the two strong baselines, Best SDL and MDL, on all datasets except QuickDraw. This indicates that i) universal representations are superior to single-domain ones when generalizing to new tasks in both seen and unseen domains, while requiring significantly fewer parameters (1 vs 8 neural networks), and ii) our distillation strategy is essential for obtaining good multi-domain representations. While MDL outperforms the best SDL in certain domains by transferring representations across them, its performance is lower than SDL in other domains, possibly due to negative transfer across the significantly diverse domains. Surprisingly, MDL still achieves the third best average rank, indicating the benefit of multi-domain representations.

Test Dataset NCC NCC+MD LR SVM Ours
ImageNet
Omniglot
Aircraft
Birds
Textures
Quick Draw
Fungi
VGG Flower
Traffic Sign
MSCOCO
MNIST
CIFAR-10
CIFAR-100
Table 4: Comparison of different classifiers incorporated into our method during the meta-test stage. NCC, MD, LR, SVM denote nearest centroid classifier, Mahalanobis distance, logistic regression, and support vector machine respectively.

Test Dataset ImageNet Omniglot Aircraft Birds Textures Quick Draw Fungi VGG Flower Traffic Sign MSCOCO MNIST CIFAR-10 CIFAR-100
Recall@ 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2 1 2
Sum
Concate
MDL
Simple CNAPS [2]
Ours
Table 5: Global retrieval performance on Meta-Dataset. In addition to few-shot learning experiments, we evaluate our method in a non-episodic retrieval task to further compare the generalization ability of our universal representations.
Figure 4: Qualitative comparison to URT in four datasets. Green and red colors indicate correct and false predictions respectively.

4.3 Further results

Varying-way five-shot setting.

After reporting results in a broad range of varying shots (e.g. up to 100 shots in some extreme cases), we further analyze our method in a five-shot setting with a varying number of categories. To this end, we follow the setting in [12] and compare our method to the three best state-of-the-art methods: Simple CNAPS, SUR and URT. In this setting, we sample a varying number of ways as in the standard setting but a fixed number of shots, to form balanced support and query sets. As shown in Table 2, the overall performance of all methods decreases on most datasets compared to the results in Table 1, indicating that this is a more challenging setting; the five-shot setting samples far fewer support images than the standard setting. The ranking of the methods changes slightly: the top-2 methods remain the same, while Simple CNAPS and SUR both obtain a 3.0 average rank. SUR performs best on MNIST, Simple CNAPS outperforms the others on CIFAR-100, and URT is top-1 on Quick Draw. Ours still achieves significantly better performance than the other methods on the remaining ten datasets.

Results in five-way one-shot setting.

Next we test the extremely challenging five-way one-shot setting on Meta-Dataset. For each task, only one image per class is seen as the support set. This setting is often used to evaluate methods in a single domain [25, 43, 53], while we adopt it for multiple domains. As shown in Table 2, our method achieves consistent gains as observed in the previous two settings, which validates the importance of good universal representations when labeled samples are limited in meta-test. Interestingly, Simple CNAPS achieves a better rank than SUR in this setting, the opposite of the previous settings.

4.4 Analyses

Here we conduct an ablation study of the different components in our framework by varying the loss function for distillation and the classifier type in meta-test.

Different distillation loss functions.

First we study different distillation loss functions, including the L2 loss, cosine distance, KL divergence and CKA for learning the multi-domain network, and report their performance in Table 3. While we apply the KL divergence loss to match the logits of the single- and multi-domain networks as in [17], the other loss functions are used to match the internal representations (the features that are fed into the classifiers) of those models. Among the individual loss functions, the best results are obtained with either CKA or the KL divergence loss, with CKA outperforming KL divergence in most domains. Although the features are first aligned with an adaptor, the L2 and cosine loss functions are not sufficient to match features from very diverse domains, and further aligning features with CKA is crucial. Note that the L2 baseline here corresponds to the method of [27]. Finally, combining CKA with KL divergence gives the best performance over the multi-domain models trained with the individual loss functions.

Different classifiers in meta-test.

Next, in Table 4, we compare the proposed adaptive mapping strategy with the nearest centroid classifier (NCC), described in Section 3.3, to parametric classifiers including Support Vector Machines (SVM) and Logistic Regression (LR) as in [51], and to non-parametric classifiers including NCC without the adaptive mapping and NCC with the Mahalanobis distance (NCC+MD) as in [2]. Among the non-parametric classifiers, NCC performs best in unseen domains when used with the Mahalanobis distance. The parametric classifiers, SVM and LR, which are trained on the limited support set, obtain very competitive results and outperform the non-parametric ones in most domains. Our method, which combines the benefits of parametric and non-parametric classifiers, outperforms SVM, LR and NCC+MD in most seen datasets, while performing worse in some unseen domains such as Traffic Sign and MNIST.

Qualitative results.

We qualitatively analyze our method and compare it to URT [29] in Fig. 4 by illustrating the nearest neighbors in four different datasets given a query image (see supplementary for more examples). Our method clearly produces more correct neighbors than URT. URT retrieves images with similar colors, shapes and backgrounds, while our method is able to retrieve semantically similar images. This again suggests that our method learns more useful and general representations.

4.5 Global retrieval

Here we go beyond the few-shot classification experiments and evaluate the generalization ability of the representations learned by our multi-domain network in a retrieval task, inspired by the metric learning literature [37, 55]. To this end, for each test image, we find the nearest images in the entire test set in the feature space and test whether they correspond to the same category. As the evaluation metric, we use Recall@$k$, which counts a prediction as positive if at least one of the $k$ closest neighbors has the same label. In Table 5, we compare our method with Simple CNAPS in Recall@1 and Recall@2 (see supplementary for more results). As URT and SUR require adaptation using a support set and no such adaptation is possible in the retrieval task, we replace them with two baselines that concatenate or sum the features from multiple domain-specific networks. Our method achieves the best performance in ten out of thirteen domains, with significant gains on Aircraft, Birds, Textures and Fungi. This strongly suggests that our multi-domain representations are key to the success of our method in the previous few-shot classification tasks.
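A short sketch of this metric follows; the use of cosine similarity in the feature space is an assumption of the sketch, as the metric of the nearest neighbor search is not specified above.

```python
# Sketch of Recall@k: a test image counts as a hit if any of its k nearest
# neighbors (itself excluded) shares its label.
import torch
import torch.nn.functional as F

def recall_at_k(features, labels, k=1):
    f = F.normalize(features, dim=1)
    sim = f @ f.T                            # cosine similarities (assumption)
    sim.fill_diagonal_(float('-inf'))        # exclude the query itself
    knn = sim.topk(k, dim=1).indices         # indices of k nearest neighbors
    hits = (labels[knn] == labels[:, None]).any(dim=1)
    return hits.float().mean().item()
```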

5 Conclusion

In this work, we demonstrate that learning a single set of universal representations, integrated with a feature refining step, achieves state-of-the-art performance on the recent Meta-Dataset benchmark. To this end, we propose to optimize the parameters of a deep neural network simultaneously over multiple domains by aligning its features with multiple single-domain networks through linear adapters and a loss function inspired by CKA. We show that the universal features can be further refined for unseen tasks from few examples by learning a transformation in a similar spirit to distance learning. Our method outperforms the state-of-the-art techniques while using fewer parameters and being more computationally efficient than other multi-domain techniques.

References

  • [1] Antreas Antoniou, Harrison Edwards, and Amos Storkey. How to train your maml. In ICLR, 2019.
  • [2] Peyman Bateni, Raghav Goyal, Vaden Masrani, Frank Wood, and Leonid Sigal. Improved few-shot visual classification. In CVPR, pages 14493--14502, 2020.
  • [3] Hakan Bilen and Andrea Vedaldi. Universal representations: The missing link between faces, text, planktons, and cat breeds. arXiv preprint arXiv:1701.07275, 2017.
  • [4] Brigit Schroeder and Yin Cui. FGVCx fungi classification challenge. Online, 2018.
  • [5] John Bronskill, Jonathan Gordon, James Requeima, Sebastian Nowozin, and Richard Turner. Tasknorm: Rethinking batch normalization for meta-learning. In ICML, pages 1153--1164, 2020.
  • [6] Wei-Yu Chen, Yen-Cheng Liu, Zsolt Kira, Yu-Chiang Frank Wang, and Jia-Bin Huang. A closer look at few-shot classification. In ICLR, 2019.
  • [7] Yinbo Chen, Xiaolong Wang, Zhuang Liu, Huijuan Xu, and Trevor Darrell. A new meta-baseline for few-shot learning. arXiv preprint arXiv:2003.04390, 2020.
  • [8] Zhao Chen, Vijay Badrinarayanan, Chen-Yu Lee, and Andrew Rabinovich. Gradnorm: Gradient normalization for adaptive loss balancing in deep multitask networks. In ICML, pages 794--803. PMLR, 2018.
  • [9] Mircea Cimpoi, Subhransu Maji, Iasonas Kokkinos, Sammy Mohamed, and Andrea Vedaldi. Describing textures in the wild. In CVPR, pages 3606--3613, 2014.
  • [10] Kevin Clark, Minh-Thang Luong, Urvashi Khandelwal, Christopher D Manning, and Quoc V Le. Bam! born-again multi-task networks for natural language understanding. In ACL, 2019.
  • [11] Guneet S Dhillon, Pratik Chaudhari, Avinash Ravichandran, and Stefano Soatto. A baseline for few-shot image classification. In ICLR, 2020.
  • [12] Carl Doersch, Ankush Gupta, and Andrew Zisserman. Crosstransformers: spatially-aware few-shot transfer. In NeurIPS, 2020.
  • [13] Nikita Dvornik, Cordelia Schmid, and Julien Mairal. Selecting relevant features from a multi-domain representation for few-shot classification. In ECCV, pages 769--786, 2020.
  • [14] Chelsea Finn, Pieter Abbeel, and Sergey Levine. Model-agnostic meta-learning for fast adaptation of deep networks. In ICLR, pages 1126--1135, 2017.
  • [15] Tommaso Furlanello, Zachary C Lipton, Michael Tschannen, Laurent Itti, and Anima Anandkumar. Born again neural networks. In ICML, 2018.
  • [16] Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Deep residual learning for image recognition. In CVPR, pages 770--778, 2016.
  • [17] Geoffrey Hinton, Oriol Vinyals, and Jeff Dean. Distilling the knowledge in a neural network. In NeurIPS Deep Learning Workshop, 2014.
  • [18] Sebastian Houben, Johannes Stallkamp, Jan Salmen, Marc Schlipsing, and Christian Igel. Detection of traffic signs in real-world images: The german traffic sign detection benchmark. In IJCNN, pages 1--8. IEEE, 2013.
  • [19] Jonas Jongejan, Henry Rowley, Takashi Kawashima, Jongmin Kim, and Nick Fox-Gieg. The Quick, Draw! A.I. experiment. Online, 2016.
  • [20] Alex Kendall, Yarin Gal, and Roberto Cipolla. Multi-task learning using uncertainty to weigh losses for scene geometry and semantics. In CVPR, pages 7482--7491, 2018.
  • [21] Gregory Koch, Richard Zemel, and Ruslan Salakhutdinov. Siamese neural networks for one-shot image recognition. In ICML deep learning workshop, volume 2. Lille, 2015.
  • [22] Simon Kornblith, Mohammad Norouzi, Honglak Lee, and Geoffrey Hinton. Similarity of neural network representations revisited. In ICML, pages 3519--3529. PMLR, 2019.
  • [23] Alex Krizhevsky, Geoffrey Hinton, et al. Learning multiple layers of features from tiny images. Citeseer, 2009.
  • [24] Brenden Lake, Ruslan Salakhutdinov, Jason Gross, and Joshua Tenenbaum. One shot learning of simple visual concepts. In Proceedings of the annual meeting of the cognitive science society, volume 33, 2011.
  • [25] Brenden M. Lake, Ruslan Salakhutdinov, and Joshua B. Tenenbaum. Human-level concept learning through probabilistic program induction. Science, 350(6266):1332--1338, 2015.
  • [26] Yann LeCun, Léon Bottou, Yoshua Bengio, and Patrick Haffner. Gradient-based learning applied to document recognition. Proceedings of the IEEE, 86(11):2278--2324, 1998.
  • [27] Wei-Hong Li and Hakan Bilen. Knowledge distillation for multi-task learning. In ECCV Workshop on Imbalance Problems in Computer Vision, pages 163--176. Springer, 2020.
  • [28] Tsung-Yi Lin, Michael Maire, Serge Belongie, James Hays, Pietro Perona, Deva Ramanan, Piotr Dollár, and C Lawrence Zitnick. Microsoft coco: Common objects in context. In ECCV, pages 740--755. Springer, 2014.
  • [29] Lu Liu, William Hamilton, Guodong Long, Jing Jiang, and Hugo Larochelle. A universal representation transformer layer for few-shot image classification. In ICLR, 2021.
  • [30] Jiaqi Ma and Qiaozhu Mei. Graph representation learning via multi-task knowledge distillation. In NeurIPS GRL Workshop, 2019.
  • [31] Subhransu Maji, Esa Rahtu, Juho Kannala, Matthew Blaschko, and Andrea Vedaldi. Fine-grained visual classification of aircraft. arXiv preprint arXiv:1306.5151, 2013.
  • [32] Thomas Mensink, Jakob Verbeek, Florent Perronnin, and Gabriela Csurka. Distance-based image classification: Generalizing to new classes at near-zero cost. TPAMI, 35(11):2624--2637, 2013.
  • [33] Erik G Miller, Nicholas E Matsakis, and Paul A Viola. Learning from one example through shared densities on transforms. In CVPR, volume 1, pages 464--471. IEEE, 2000.
  • [34] Thao Nguyen, Maithra Raghu, and Simon Kornblith. Do wide and deep networks learn the same things? uncovering how neural network representations vary with width and depth. In ICLR, 2021.
  • [35] Alex Nichol, Joshua Achiam, and John Schulman. On first-order meta-learning algorithms. arXiv preprint arXiv:1803.02999, 2018.
  • [36] Maria-Elena Nilsback and Andrew Zisserman. Automated flower classification over a large number of classes. In 2008 Sixth Indian Conference on Computer Vision, Graphics & Image Processing, pages 722--729. IEEE, 2008.
  • [37] Hyun Oh Song, Yu Xiang, Stefanie Jegelka, and Silvio Savarese. Deep metric learning via lifted structured feature embedding. In CVPR, pages 4004--4012, 2016.
  • [38] Adam Paszke, Sam Gross, Francisco Massa, Adam Lerer, James Bradbury, Gregory Chanan, Trevor Killeen, Zeming Lin, Natalia Gimelshein, Luca Antiga, Alban Desmaison, Andreas Kopf, Edward Yang, Zachary DeVito, Martin Raison, Alykhan Tejani, Sasank Chilamkurthy, Benoit Steiner, Lu Fang, Junjie Bai, and Soumith Chintala. Pytorch: An imperative style, high-performance deep learning library. In H. Wallach, H. Larochelle, A. Beygelzimer, F. dAlché-Buc, E. Fox, and R. Garnett, editors, NeurIPS, pages 8024--8035. Curran Associates, Inc., 2019.
  • [39] Ethan Perez, Florian Strub, Harm De Vries, Vincent Dumoulin, and Aaron Courville. Film: Visual reasoning with a general conditioning layer. In AAAI, volume 32, 2018.
  • [40] Mary Phuong and Christoph Lampert. Towards understanding knowledge distillation. In ICML, pages 5142--5151, 2019.
  • [41] Sylvestre-Alvise Rebuffi, Hakan Bilen, and Andrea Vedaldi. Learning multiple visual domains with residual adapters. In NeurIPS, 2017.
  • [42] Sylvestre-Alvise Rebuffi, Hakan Bilen, and Andrea Vedaldi. Efficient parametrization of multi-domain deep neural networks. In CVPR, pages 8119--8127, 2018.
  • [43] Mengye Ren, Eleni Triantafillou, Sachin Ravi, Jake Snell, Kevin Swersky, Joshua B Tenenbaum, Hugo Larochelle, and Richard S Zemel. Meta-learning for semi-supervised few-shot classification. In ICLR, 2018.
  • [44] James Requeima, Jonathan Gordon, John Bronskill, Sebastian Nowozin, and Richard E Turner. Fast and flexible multi-task classification using conditional neural adaptive processes. In CVPR, 2019.
  • [45] Adriana Romero, Nicolas Ballas, Samira Ebrahimi Kahou, Antoine Chassang, Carlo Gatta, and Yoshua Bengio. Fitnets: Hints for thin deep nets. In ICLR, 2015.
  • [46] Olga Russakovsky, Jia Deng, Hao Su, Jonathan Krause, Sanjeev Satheesh, Sean Ma, Zhiheng Huang, Andrej Karpathy, Aditya Khosla, Michael Bernstein, et al. Imagenet large scale visual recognition challenge. IJCV, 115(3):211--252, 2015.
  • [47] Tonmoy Saikia, Thomas Brox, and Cordelia Schmid. Optimized generic feature learning for few-shot classification across domains. arXiv preprint arXiv:2001.07926, 2020.
  • [48] Jake Snell, Kevin Swersky, and Richard S Zemel. Prototypical networks for few-shot learning. In NeurIPS, 2017.
  • [49] Flood Sung, Yongxin Yang, Li Zhang, Tao Xiang, Philip HS Torr, and Timothy M Hospedales. Learning to compare: Relation network for few-shot learning. In CVPR, pages 1199--1208, 2018.
  • [50] Yonglong Tian, Dilip Krishnan, and Phillip Isola. Contrastive representation distillation. In ICLR, 2020.
  • [51] Yonglong Tian, Yue Wang, Dilip Krishnan, Joshua B Tenenbaum, and Phillip Isola. Rethinking few-shot image classification: a good embedding is all you need? In ECCV, 2020.
  • [52] Eleni Triantafillou, Tyler Zhu, Vincent Dumoulin, Pascal Lamblin, Utku Evci, Kelvin Xu, Ross Goroshin, Carles Gelada, Kevin Swersky, Pierre-Antoine Manzagol, et al. Meta-dataset: A dataset of datasets for learning to learn from few examples. In ICLR, 2020.
  • [53] Oriol Vinyals, Charles Blundell, Timothy Lillicrap, Koray Kavukcuoglu, and Daan Wierstra. Matching networks for one shot learning. In NeurIPS, 2016.
  • [54] Catherine Wah, Steve Branson, Peter Welinder, Pietro Perona, and Serge Belongie. The caltech-ucsd birds-200-2011 dataset. California Institute of Technology, 2011.
  • [55] Lu Yu, Vacit Oguz Yazici, Xialei Liu, Joost van de Weijer, Yongmei Cheng, and Arnau Ramisa. Learning metrics from teachers: Compact networks for image embedding. In CVPR, pages 2907--2916, 2019.
  • [56] Tianhe Yu, Saurabh Kumar, Abhishek Gupta, Sergey Levine, Karol Hausman, and Chelsea Finn. Gradient surgery for multi-task learning. In ICLR, 2020.
  • [57] Matthew D Zeiler. Adadelta: an adaptive learning rate method. arXiv preprint arXiv:1212.5701, 2012.

Appendix A Implementation details

In all experiments we build our method on ResNet-18 [16] backbone for both single-domain and multi-domain networks.

A.1 Training details of single-domain models

We train one ResNet-18 model for each training dataset. For optimization, we follow the training protocol in [13]. Specifically, we use an SGD optimizer and cosine annealing for all experiments, with a momentum of 0.9 and weight decay. The learning rate, batch size, annealing frequency, and maximum number of iterations are shown in Table 6. To regularize training, we also use the exact same data augmentations as in [13], e.g. random crops and random color augmentations.

Dataset learning rate batch size annealing freq. max. iter.
ImageNet 64 48,000 480,000
Omniglot 16 3000 50,000
Aircraft 8 3000 50,000
Birds 16 3000 50,000
Textures 32 1500 50,000
Quick Draw 64 48,000 480,000
Fungi 32 15,000 480,000
VGG Flower 8 1500 50,000
Table 6: Training hyper-parameters of single domain learning.

A.2 Training details of our method

In the multi-domain network, we share all layers but the last classifier across domains. To train the multi-domain network, we use the same optimizer, weight decay, and scheduler as for the single-domain models, learning for 240,000 iterations. The learning rate is 0.03 and the annealing frequency is 48,000. Similar to [52], where training episodes come from the ImageNet data source with 50% probability, each training batch for our multi-domain network consists of 50% data coming from ImageNet. In other words, the batch size for ImageNet equals the total batch size of the other 7 datasets combined.

We set $\lambda^{kl}_{k}$ and $\lambda^{f}_{k}$ to 4 for ImageNet and 1 for the other datasets, respectively. We linearly anneal the weights to zero as $\lambda_t = \lambda_0 (1 - t/T)$, where $t$ is the current iteration and $T$ is the total number of iterations over which the weight is annealed to zero. Here $T = m \times T_a$, where $T_a$ is the annealing frequency (48,000 in this work). We search $m$ based on cross-validation over the validation sets of the 8 training datasets; $m$ is 5 (i.e. $T = 240{,}000$) for ImageNet, 2 for Omniglot, Quick Draw and Fungi, and 1 for the other datasets. For all experiments, early stopping is performed based on cross-validation over the validation sets of the 8 training datasets.
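A one-line sketch of this schedule, under the reconstruction above (the exact parameterization in the released code may differ):

```python
# Linear annealing of a loss weight to zero over T iterations; lam0 is the
# initial weight (4 for ImageNet, 1 for the other datasets in this work).
def annealed_lambda(lam0: float, t: int, T: int) -> float:
    return lam0 * max(0.0, 1.0 - t / T)
```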

For the optimization of feature adaptation during the meta-test stage, we initialize $A_{\vartheta}$ as an identity matrix, which allows the NCC to start from the original features produced by our universal network and optimize the adaptor from a good starting point. Similar to the optimization in [13], we optimize $\vartheta$ for 40 iterations using Adadelta [57] as the optimizer, with a learning rate of 0.1 for the first eight datasets and 1 for the last five datasets.

Appendix B More results

In this section, we first evaluate each single-domain model for few-shot classification on each test dataset. Next, we evaluate the effect of the adaptors for aligning features in knowledge distillation. We then show complete results for the varying-way five-shot and five-way one-shot settings. Finally, more qualitative and global retrieval results are reported.

B.1 Complete results of single domain learning

Test Dataset \ Train Dataset ImageNet Omniglot Aircraft Birds Textures Quick Draw Fungi VGG Flower
ImageNet
Omniglot
Aircraft
Birds
Textures
Quick Draw
Fungi
VGG Flower
Traffic Sign
MSCOCO
MNIST
CIFAR-10
CIFAR-100
Table 7: Results of all single domain learning models. Mean accuracy and 95% confidence interval are reported. The first eight datasets are seen during training and the last five datasets are unseen for test only.
Varying-Way Five-Shot / Five-Way One-Shot
Test Dataset Simple CNAPS [2] SUR [13] URT [29] Ours Simple CNAPS [2] SUR [13] URT [29] Ours
ImageNet
Omniglot
Aircraft
Birds
Textures
Quick Draw
Fungi
VGG Flower
Traffic Sign
MSCOCO
MNIST
CIFAR-10
CIFAR-100
Average Rank 3.0 3.0 2.5 1.5 2.8 3.5 2.3 1.3
Table 8: Complete results of the Varying-Way Five-Shot and Five-Way One-Shot settings. Mean accuracies with 95% confidence intervals are reported.

To study universal representation learning from multiple datasets, we train one network on each training dataset, use each single-domain network as the feature extractor, and test it for few-shot classification on each dataset. This involves evaluating 8 single-domain networks on 13 datasets using the Nearest Centroid Classifier (NCC). Table 7 shows the results of the single-domain learning models, where each column presents the mean accuracy and 95% confidence interval of a single-domain network trained on one dataset (e.g. ImageNet) and evaluated on the 13 test datasets. The average accuracy and 95% confidence intervals are computed over 600 few-shot tasks. Numbers in bold indicate the best accuracy per dataset.

As shown in Table 7, the features of the ImageNet model generalize well, achieving the best results on four out of eight seen datasets (ImageNet, Birds, Textures, VGG Flower) and four out of five previously unseen datasets (Traffic Sign, MSCOCO, CIFAR-10, CIFAR-100). The models trained on Omniglot, Aircraft, Quick Draw, and Fungi perform best on their corresponding datasets, while the Omniglot model also generalizes well to MNIST, which contains images of a similar style to Omniglot. We then pick the best performing model per dataset, forming the best single-domain model (Best SDL), which serves as a very competitive baseline for universal representation learning.

Test Dataset Ours (CKA w/o adaptors) Ours (CKA)
ImageNet
Omniglot
Aircraft
Birds
Textures
Quick Draw
Fungi
VGG Flower
Traffic Sign
MSCOCO
MNIST
CIFAR-10
CIFAR-100
Table 9: Results of our method using CKA with and without adaptors. Mean accuracy and 95% confidence interval are reported. Here, Ours (CKA w/o adaptors) indicates that adaptors are not applied when aligning features. All results are obtained with feature adaptation during the meta-test stage.
Test Dataset ImageNet Omniglot Aircraft Birds Textures Quick Draw Fungi VGG Flower
Recall@ 1 2 4 8 1 2 4 8 1 2 4 8 1 2 4 8 1 2 4 8 1 2 4 8 1 2 4 8 1 2 4 8
Sum
Concate
MDL
Simple CNAPS [2]
Ours
Table 10: Global retrieval performance on Meta-Dataset (seen datasets). In addition to few-shot learning experiments, we evaluate our method in a non-episodic retrieval task to further compare the generalization ability of our universal representations.
Test Dataset Traffic Sign MSCOCO MNIST CIFAR-10 CIFAR-100
Recall@ 1 2 4 8 1 2 4 8 1 2 4 8 1 2 4 8 1 2 4 8
Sum
Concate
MDL
Simple CNAPS [2]
Ours