Class Regularization: Improve Few-shot Image Classification by Reducing Meta Shift

12/18/2019 · by Da Chen, et al.

Few-shot image classification requires the classifier to cope robustly with unseen classes even when only a few samples are available for each class. Recent advances benefit from the meta-learning process, where episodic tasks are formed to train a model that can adapt to class change. However, these tasks are independent of each other, and existing works mainly rely on the limited samples of the individual support set in a single meta task. This strategy leads to severe meta shift issues across multiple tasks, meaning the learned prototypes or class descriptors are not stable, as each task only involves its own support set. To avoid this problem, we propose a concise Class Regularization Network which aggregates the embedding features of all samples in the entire training set and further regularizes the generated class descriptor. The key is to train a class encoder and decoder structure that can encode the embedding sample features into a class domain with trained class bases, and generate a more stable and general class descriptor from the decoder. We evaluate our work by extensive comparisons with previous methods on two benchmark datasets (MiniImageNet and CUB). The results show that our method achieves state-of-the-art performance over previous work.


I Introduction

Fig. 1: A high-level description of meta shift. $S_k^t$ and $P_k^t$ are the support set and the class descriptor of the $k$-th class in task $t$, respectively. During training, the baseline method (A) generates class descriptors that are task-dependent and only concern the classification result in the current task. The generated class descriptors across tasks are not stable and shift, e.g., the descriptor of the first class is biased by the particular support samples drawn in each task. The proposed method (B) utilizes a class domain with memory to regularize the class descriptor construction and avoid descriptor shifting.

The human ability to understand new concepts with only a few examples has inspired research on few-shot learning in recent years. The main task is to learn a model that can classify new categories given limited training data. Different from classification models [10, 23, 28] trained on large labeled datasets, a few-shot learning model only relies on a few samples of each class (10, 5, or even fewer). This can easily lead to overfitting during training. To address this issue, Vinyals et al. [29] propose an attention mechanism which can learn an embedding of labelled samples from the support set and achieve good classification performance. This mechanism can be further enhanced by episodes, which sub-sample categories and the associated data to simulate few-shot tasks during training. The episodic training process also benefits meta-learning [21], which can learn a probabilistic model to predict a decision boundary between categories given a few samples from each category.

By focusing on retrieving transferable embeddings from the dataset, as well as the relation between images and the associated category descriptions, existing meta-learning approaches, such as ProtoNet [24] and RelationNet [27], prove to be effective for few-shot learning. However, they are often restricted by only targeting individual tasks during training, and thus cannot efficiently explore the variation of all classes in the whole training set from a more general view. Due to the bias of the class subset chosen for an individual training episode, the generated class descriptors/prototypes for the same class in different episodes can be sparsely distributed in the feature space. As illustrated in Fig. 1A, the meta task targets of ProtoNet [24], i.e., the class descriptors/prototypes (denoted descriptors for the rest of the paper) of the classes, are only required to be distinguishable under a certain metric for the current task. The class descriptors in different tasks can still be distant from each other in the feature space, as individual tasks are considered separately. In this case, the class descriptors are not stable among tasks during training, so the average embedding easily causes misclassification at the test stage. We define this as the meta shift problem.

To solve the meta shift problem, we propose the Class Regularization Network to stabilize class descriptors across all tasks in the whole meta-learning process (see Fig. 1B). The resultant class descriptor can be generally applied to classify various query samples in a stable manner. The network first leverages comprehensive embedding sample features to form a class representation in a class domain composed of trained semantic bases via a class encoder. The generated class representation can then be decoded as a more stable and general class descriptor in the feature domain. The obtained class descriptors are applied to classify the query samples via a metric module which is a combination of a fixed distance measure and a trainable relation module.

We demonstrate that the proposed method achieves state-of-the-art results on two benchmark datasets compared to both classic and more recent methods. For 1-shot, 5-shot and 10-shot tasks on MiniImageNet, it improves the accuracy from 61.20% to 63.37%, 76.70% to 79.03%, and 80.80% to 83.30%, respectively. For the fine-grained CUB dataset, it achieves nearly 5% improvement on the 1-shot task and improves the 5-shot result from 81.90% to 83.53%. In addition, we illustrate that the proposed class regularization module effectively fine-tunes the embedding network to better tackle the covariate shift issue in few-shot learning. Further quantitative and qualitative evaluations show that the proposed method can effectively regularize the descriptor construction and improve classification results.

Fig. 2: Meta shift measurement: statistics on the distance between class descriptors of the class 'worm' and their mean over 500 tests for the baseline method (A) and the proposed method (B). Five samples are randomly picked (5-shot task) to construct the class descriptor. This process is repeated for 500 tests to generate 500 class descriptors for each method. The Euclidean distance from the 'mean' class descriptor to the class descriptor in each test indicates the variation of the generated class descriptors, which can be applied to measure the meta shift issue.

II Related Work

Few-shot learning is an active research topic that has been extensively studied. A number of works aim to improve the robustness of the training process for few-shot learning. Garcia et al. [7] propose a graph neural network based on the generic message-passing inference method. Zhao et al. [33] split the features into three orthogonal parts to improve the classification performance for few-shot learning, allowing simultaneous feature selection and dense estimation. Chen et al. [3] present a comprehensive analysis and investigate the cross-domain generalization ability of many existing few-shot learning methods. They also propose a new method which achieves state-of-the-art results on the CUB dataset [30]. Chen et al. [34] propose a Self-Jig algorithm to augment the input data in few-shot learning by synthesizing new images that are either labelled or unlabelled. Chu et al. [4] augment the input images by extracting varying sequences of patches on every forward pass, with discriminative observed information using maximum-entropy reinforcement learning.

A popular strategy for few-shot learning is meta-learning (also called learning-to-learn) with multiple auxiliary tasks [5, 6, 9, 13, 17, 20, 27, 29, 32]. The key question is how to robustly accelerate the network's learning progress without overfitting on limited training data. Finn et al. [6] propose MAML to search for the best initial weights through gradient descent, making the fine-tuning of the network easier. REPTILE [18] simplifies the complex computation of MAML by incorporating an $L_2$ loss, but still operates in a high-dimensional parameter space. To reduce the complexity, Rusu et al. [22] propose a network called LEO to learn a low-dimensional latent embedding of the model. CAML [13] extends MAML by partitioning the model parameters into context parameters and shared parameters, enabling a larger network without over-fitting.

Another stream of meta-learning based approaches [19, 24, 27, 29] attempts to learn a deep embedding model that can effectively project the input samples into a specific feature space. The samples can then be classified by the nearest-neighbour criterion using a distance function such as cosine or Euclidean distance. Koch et al. [2] propose the Siamese network to extract embedding features from input images and bring images of the same class together. Matching Network [29] utilizes an augmented neural network for feature embedding, forming the basis for metric learning. The most relevant work to ours is ProtoNet [24], which models the class descriptor of each class with a simple average pooling over embedding sample features. However, its distance estimation only concerns the current task rather than the whole training set. Many approaches extend ProtoNet by improving its class descriptor. RelationNet [27] exploits a relation network with a learnable non-linear comparator instead of a fixed linear one. TADAM [19] produces a task-dependent metric space by conditioning a learner on the task set. Li et al. [14] propose to replace the distance measured at the image level with a local-descriptor based image-to-class measure. Liu et al. [16] propose a transductive propagation network (TPN) which learns both the parameters of the feature embedding and the graph construction. Unlike the methods mentioned above, which fine-tune the embedding network directly, Sun et al. [26] propose to apply neuron-level scaling and shifting on the embedding network with a hard-task meta-batch setting.

In this paper, we propose a novel network that learns much more stable class descriptors to generally represent all classes across the whole meta-learning process. This avoids meta shift and achieves state-of-the-art performance.

III Class Regularization Network

In this section, we first present preliminaries on the meta-learning setting with independent episodic tasks and the subsequent meta shift problem, then give an overview of our approach, and finally elaborate on the details of our method.

III-A Meta-learning for few-shot learning

Few-shot learning is a challenging problem, as it requires classifying unseen queries with only limited data for training. One solution is to apply a meta-learning process composed of multiple $N$-way $K$-shot episodic classification tasks [29]. For each classification task, we have $N$ classes with $K$ samples from each class. The entire training dataset can be represented by $D = \{(x_i, y_i)\}_{i=1}^{I}$, where $I$ is the total number of samples in $D$, and $y_i$ is the corresponding label of sample $x_i$. For a specific task $t$, a support set $S^t$ and a query set $Q^t$ are randomly selected from $D$: (a) the support set for task $t$ is denoted by $S^t = \{(x_i, y_i)\}_{i=1}^{N \times K}$ ($N$-way $K$-shot); (b) the query set is $Q^t = \{(\hat{x}_j, \hat{y}_j)\}_{j=1}^{J}$, where $J$ is the number of query samples in each meta task.
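To make the $N$-way $K$-shot setup concrete, the following Python sketch shows how one episodic task could be sampled from a labelled training set. The function and variable names are our own illustration of the setup above, not code from the paper.

import random
from collections import defaultdict

def sample_episode(dataset, n_way=5, k_shot=5, n_query=15):
    # dataset: list of (sample, label) pairs; returns (support, query) lists.
    by_class = defaultdict(list)
    for x, y in dataset:
        by_class[y].append(x)
    # Pick N classes, then K support and n_query query samples per class.
    classes = random.sample(list(by_class), n_way)
    support, query = [], []
    for c in classes:
        picks = random.sample(by_class[c], k_shot + n_query)
        support += [(x, c) for x in picks[:k_shot]]
        query += [(x, c) for x in picks[k_shot:]]
    return support, query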

Fig. 3: Different components of a bird are highlighted in different colors. These indicate the semantic bases in the class domain. The saliency of the components implies the class's location in the class domain.

III-B Meta shift problem

Fig. 4: Overview of the proposed Class Regularization Network. The network is composed of four major parts: a feature embedding network, a class encoder for generating the class representation with semantic bases in the class domain, a class decoder to transform the class representation back to the feature domain and form a general class descriptor, and a metric module to compare the class descriptor with the query sample features.

During the meta-learning process, most existing methods only focus on the selected samples in the support and query sets ($S^t$ and $Q^t$) of the current task (such as 5-way 5-shot classification), while ignoring the overall sample distribution within the training set $D$. We observe that for different training tasks, the support samples from the same class can be different, as they are picked randomly from the entire training set $D$. This leads to unstable episodic task training in meta-learning. This instability among meta-learning tasks is defined as meta shift.

For example, as mentioned in Section II, the recent popular framework ProtoNet utilizes a feature embedding network to map all input samples in the support set $S_k$ from the $k$-th class to a mean vector $P_k$, which is taken as the class descriptor in the common feature space:

$P_k = \frac{1}{|S_k|} \sum_{(x_i, y_i) \in S_k} f_\varphi(x_i)$   (1)

where $f_\varphi$ is the embedding function.
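As a minimal illustration of Eq. 1, the snippet below computes per-class prototypes by averaging embedded support features; the embed argument stands in for the embedding function $f_\varphi$ and is an assumption of this sketch.

import torch

def class_prototypes(embed, support_x, support_y, n_way):
    # support_x: (N*K, ...) input tensor; support_y: (N*K,) labels in [0, n_way).
    feats = embed(support_x)                          # (N*K, D) embedding features
    protos = torch.stack([feats[support_y == c].mean(dim=0)   # Eq. 1: mean per class
                          for c in range(n_way)])     # (N, D) class descriptors
    return protos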

As it only focuses on the current task, the generated class descriptors are sparsely distributed in the feature space. As shown in Fig. 2, the variation of the generated class descriptors is considerably high (magenta curve), hindering the ability to accurately classify various query images.

Different from prior work, we aim at a better regularization over all the training tasks using a class regularization network. This network resolves the meta shift problem, generates more stable class descriptors (see Fig. 1B), and further yields significantly improved classification performance.

III-C Key idea

To solve this problem, we consider how humans classify objects given few examples. Conceptually, humans find key components of the objects and perform classification accordingly. For example, as shown in Fig. 3, the feathers of the wings (green), the shape of the neck (blue), the mouth (red), and the tail (yellow) can be very helpful for classifying birds. These key components can be treated as semantic bases forming a class domain, and all classes can be represented in this domain with respect to the bases.

Inspired by the above, we propose a class encoder-decoder framework to form a class domain with trained semantic bases and then aggregate the sample features (within the same class) from the feature domain into the class domain. As the whole training set contains far more samples of each class than an individual support set, aggregating embedding features across tasks yields a much richer description of each class than the sparse feature distribution available within a single task. For each task, the aggregated sample features update the encoder as a memory of how to form a better class representation in the class domain with semantic bases. Representing classes over semantic bases effectively avoids the task dependence that causes the meta shift problem. This makes the encoder able to describe various classes in a more stable way, shown as the cyan curve in Fig. 2. Each class representation in the class domain is further reconstructed back into the feature domain to form a more stable and general class descriptor, and query samples are classified according to the metric module. Benefiting from a robust embedding network and a stable class descriptor, the proposed method improves classification performance. An overview of the proposed network is shown in Fig. 4.

III-D Proposed method

In this subsection, we present the details of the network structure, including the embedding network, the class encoder-decoder network, and the metric module.

III-D1 Embedding samples

Similar to previous metric-learning based methods, the input samples are first fed into an embedding network:

$e_i = f_\varphi(x_i)$   (2)

where $e_i$ is the embedding feature of the input sample $x_i$ from either the support set $S$ or the query set $Q$.

III-D2 Encoding classes

Given the embedding features, we use a class encoder $E$ to generate a class representation with trainable semantic bases in the class domain. This is done by aggregating the class feature $F_k$, composed of the embedding features of all samples of the $k$-th class in the support set. The representation of the $k$-th class in the class domain is denoted as $R_k$, which can be written as:

$R_k = E(F_k)$   (3)

For simplicity, we use $R$ and $F$ to replace $R_k$ and $F_k$ in the rest of this subsection.

Suppose the class embedding $F$ contains $K$ sample features ($K$-shot task) and $R$ is formed by $M$ semantic bases; it is easy to see that $F$ can be represented by a $K \times D$ matrix ($D$ is the size of the embedding feature for a single sample), and $R$ can be represented as an $M \times D$ matrix.

Similar to [8], by updating the form of the semantic bases in each episodic task, we statistically gather information for the class representation based on all samples of that class across all tasks. More specifically, the aggregation of the residuals between each sample feature in one class and the semantic bases is applied to construct the class representation. For each class in an episodic task, the encoder constructs the class representation based on multiple semantic bases. Mathematically, the representation on each semantic basis $b_m$ (the $m$-th row $R_m$ in $R$) is updated based on $F$ in the following way:

$R_m = \sum_{i=1}^{K} a_i^m (e_i - b_m)$   (4)

where $e_i$ is the $i$-th sample feature in the current task, $a_i^m = 1$ if $b_m$ is the closest basis to $e_i$, and $a_i^m = 0$ otherwise. The sum of residuals is thus assigned to $R_m$. The weighted residual indicates the saliency of the sample feature's response to each semantic basis. After a normalization operation, the output is a class representation in the class domain.

However, the operation based on Eq. 4 is not differentiable due to the discontinuities when assigning $e_i$ to semantic basis $b_m$. To solve this problem, $a_i^m$ is replaced by a soft assignment of sample features to the semantic bases:

$a_i^m = \frac{e^{-\alpha \|e_i - b_m\|^2}}{\sum_{m'=1}^{M} e^{-\alpha \|e_i - b_{m'}\|^2}}$   (5)

where the parameter $\alpha$ controls the sensitivity to the residual.

It can be seen that each sample feature $e_i$ is assigned to basis $b_m$ in a weighted manner using weight $a_i^m$, and the resulting class representation $R$ lies in the class domain. The encoder parameters can be represented by the semantic bases and the soft-assignment parameters. More details can be found in the supplementary material.
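Eqs. 4 and 5 follow the residual-aggregation pattern of soft-assignment encoders in the spirit of [8]; below is a minimal sketch of the soft version, where the basis matrix and the sensitivity alpha would be learnable parameters. All names are ours, and the exact layer layout of the paper's encoder may differ.

import torch
import torch.nn.functional as F

def encode_class(feats, basis, alpha=10.0):
    # feats: (K, D) sample features; basis: (M, D) semantic bases -> (M, D) representation.
    d2 = torch.cdist(feats, basis) ** 2               # (K, M) squared distances
    w = F.softmax(-alpha * d2, dim=1)                 # Eq. 5: soft assignment a_i^m
    resid = feats.unsqueeze(1) - basis.unsqueeze(0)   # (K, M, D) residuals e_i - b_m
    rep = (w.unsqueeze(-1) * resid).sum(dim=0)        # Eq. 4: weighted residual sum
    return F.normalize(rep, dim=-1)                   # normalization step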

III-D3 Decoding classes

With the aforementioned encoded class representation in the class domain, we are able to decode it and transform it from the class domain back into the feature domain as a more general and stable class descriptor.

In the class decoder, we include a transformation function $T$ to obtain the general class descriptor $P_k$ given the class representation $R_k$ in the class domain:

$P_k = T(R_k)$   (6)

Note that the dimension of the transformed class descriptor $P_k$ is also $D$, the same as the embedding feature, as they are both in the feature domain.
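Since Eq. 6 only needs to map the $M \times D$ class representation back to a single $D$-dimensional descriptor, one possible realization, sketched below under our own assumptions, is a 1x1 convolution over the basis axis, loosely mirroring the convolutional block described in Section IV-A.

import torch.nn as nn

class ClassDecoder(nn.Module):
    # Transforms an (M, D) class representation into a (D,) class descriptor (Eq. 6).
    def __init__(self, n_basis):
        super().__init__()
        self.mix = nn.Conv1d(n_basis, 1, kernel_size=1)   # learned mixing of bases

    def forward(self, rep):                # rep: (M, D)
        return self.mix(rep.unsqueeze(0)).squeeze()       # (D,) descriptor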

As shown in Fig. 1 and Fig. 2, ProtoNet simply applies average pooling over the sample features of one class to obtain the class descriptor. As the input support set only has a small number of samples in each task, the associated features are not consistent among tasks, raising the meta shift issue where the same class has unstable descriptors across different tasks.

Our method, on the other hand, extracts a better class descriptor based on the distribution of all classes. This is done by decomposing the class representation onto trainable semantic bases in a class domain. The decoded class descriptor proves to be more stable across all tasks. More details are discussed in Section IV.

III-D4 Metric Module

To classify the query image, classic few-shot classification methods such as ProtoNet [24] apply a simple fixed distance metric, e.g., cosine or Euclidean distance. In this work, we combine a general relation module [27] with a fixed distance metric via a co-training mechanism to form the metric module, resulting in a more robust similarity measurement for comparing the embedding feature of a query image to the general class descriptor $P_k$ of the $k$-th class. Here we discuss the metric module in detail.

Similar to ProtoNet [24], we employ a Euclidean distance function $d(\cdot,\cdot)$ and produce a distribution over all classes given a sample $\hat{x}$ from the query set $Q$. The distribution is based on a softmax over the distances between the sample feature (in the query set) and the general class descriptors. The loss function can be defined as:

$L_d = -\log \frac{\exp(-d(f_\varphi(\hat{x}), P_k))}{\sum_{k'} \exp(-d(f_\varphi(\hat{x}), P_{k'}))}$   (7)

Apart from the fixed distance function, we also include a trainable relation module. Given the embedding feature $f_\varphi(\hat{x})$ of the query image and a class descriptor $P_k$, the output of the relation module indicates the correlation between the two. We treat this output as the relation score [27] between classes and query images. Similar to Eq. 7, given the correct class label $k$ for $\hat{x}$, the loss for the relation module can be written as:

$L_r = -\log \frac{\exp(g(f_\varphi(\hat{x}) \oplus P_k))}{\sum_{k'} \exp(g(f_\varphi(\hat{x}) \oplus P_{k'}))}$   (8)

where $g$ is the relation function and $\oplus$ denotes concatenation. The total loss of the network can then be summarized as:

$L = \lambda_1 L_d + \lambda_2 L_r + \lambda_3 \|\Theta\|_2^2$   (9)

where $\lambda_1$, $\lambda_2$, $\lambda_3$ are the weights of the two losses and the regularization term, and $\Theta$ denotes the training parameters of the class encoder, class decoder and relation module. The algorithmic process of this module and the entire network can be found in the supplementary material.
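A hedged sketch of Eqs. 7-9 follows: the distance term uses a softmax over negative squared Euclidean distances, and the relation term scores concatenated (query, descriptor) pairs. relation_net is assumed to map a (..., 2D) input to a scalar score, and the weight-decay term of Eq. 9 is left to the optimizer.

import torch
import torch.nn.functional as F

def metric_losses(query_feats, query_y, descriptors, relation_net,
                  w_dist=1.0, w_rel=1.0):
    # query_feats: (Q, D); query_y: (Q,); descriptors: (N, D).
    d2 = torch.cdist(query_feats, descriptors) ** 2       # (Q, N)
    loss_dist = F.cross_entropy(-d2, query_y)             # Eq. 7
    q = query_feats.unsqueeze(1).expand(-1, descriptors.size(0), -1)
    p = descriptors.unsqueeze(0).expand(query_feats.size(0), -1, -1)
    scores = relation_net(torch.cat([q, p], dim=-1)).squeeze(-1)  # (Q, N)
    loss_rel = F.cross_entropy(scores, query_y)           # Eq. 8
    return w_dist * loss_dist + w_rel * loss_rel          # Eq. 9 (without weight decay)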

IV Evaluation

In this section, we first describe the network architecture and the datasets used in our evaluation, then detail the training procedure, and finally show quantitative comparisons with baseline methods along with a thorough discussion of the advantages over the most relevant work [24].

IV-A Network architecture

The overall architecture of the proposed network is shown in Fig. 4. Note that it can incorporate different feature embedding networks. In our experiments, we apply two networks commonly used in the few-shot image classification field [14, 26]: ResNet12, a residual network [10] with three residual blocks followed by an average pooling layer for general feature embedding, and 4 Conv, with four convolutional blocks for feature extraction. The class encoder has a convolutional block and a softmax operator for soft assignment (see details in the supplemental) to produce a sparse matrix which includes all weights for the semantic bases. With a multiplication between the weights and the residuals followed by a normalization layer, the class representation composed of semantic bases in the class domain is obtained. The decoder exploits a normal convolutional block to transform the class representation from the class domain to the general class descriptor in the feature domain. To compare the obtained general class descriptor with the embedding features of the samples in the query set, a fixed distance metric (i.e., Euclidean distance) and a learnable relation module (two Conv + two FC) based similarity measurement are jointly applied, while a cross-entropy loss is employed as the classification loss. The network then updates the metric module along with the encoder-decoder structure, and fine-tunes the feature embedding network simultaneously.

IV-B Datasets

MiniImageNet, as proposed in [29], is a benchmark for evaluating few-shot learning methods. This dataset is a randomly selected subset of ImageNet, a popular image classification dataset. MiniImageNet contains 60,000 images from 100 classes, each class having 600 images. We follow the data split strategy in [21] to sample images of 64 classes for training, 16 classes for validation, and 20 classes for testing.

Caltech-UCSD CUB-200-2011 [30] is a dataset for fine-grained classification. It contains 200 classes of birds with 11,788 images in total. For the few-shot classification task, we follow the split in [12] for evaluation: the 200 bird species are randomly split into 100 classes for training, 50 classes for validation, and 50 classes for testing.

 

Embedding Net | Image Size | Basis | Learning rate | Pre-train learning rate
ResNet12      |            | 16    |               |
4 Conv        |            | 8     |               |

TABLE I: Training parameters for different embedding networks on the MiniImageNet and CUB datasets.

 

Baselines | Embedding Net | 1-Shot 5-Way | 5-Shot 5-Way | 10-Shot 5-Way
Matching Networks [29] | 4 Conv | 43.56 ± 0.84% | 55.31 ± 0.73% | -
MAML [6] | 4 Conv | 48.70 ± 1.84% | 63.11 ± 0.92% | -
RelationNet [27] | 4 Conv | 50.44 ± 0.82% | 65.32 ± 0.70% | -
REPTILE [18] | 4 Conv | 49.97 ± 0.32% | 65.99 ± 0.58% | -
ProtoNet [24] | 4 Conv | 49.42 ± 0.78% | 68.20 ± 0.66% | 74.30%
Baseline* [3] | 4 Conv | 41.08 ± 0.70% | 54.50 ± 0.66% | -
Spot&learn [4] | 4 Conv | 51.03 ± 0.78% | 67.96 ± 0.71% | -
DN4 [14] | 4 Conv | 51.24 ± 0.74% | 71.02 ± 0.64% | -
Ours_Conv | 4 Conv | 52.00 ± 0.20% | 70.98 ± 0.16% | -
Discriminative k-shot [1] | ResNet34 | 56.30 ± 0.40% | 73.90 ± 0.30% | 78.50 ± 0.30%
Self-Jig(SVM) [34] | ResNet50 | 58.80 ± 1.36% | 76.71 ± 0.72% | -
Qiao-WRN [20] | Wide-ResNet28 | 59.60 ± 0.41% | 73.74 ± 0.19% | -
LEO [22] | Wide-ResNet28 | 61.76 ± 0.08% | 77.59 ± 0.12% | -
SNAIL [17] | ResNet12 | 55.71 ± 0.99% | 68.88 ± 0.92% | -
ProtoNet+ [24] | ResNet12 | 56.50 ± 0.40% | 74.20 ± 0.20% | 78.60 ± 0.40%
RelationNet [27] | ResNet12 | 58.20 ± 0.30% | 74.35 ± 0.23% | 78.65 ± 0.30%
CAML [13] | ResNet12 | 59.23 ± 0.99% | 72.35 ± 0.71% | -
TPN [16] | ResNet12 | 59.46% | 75.65% | -
MTL [26] | ResNet12 | 61.20 ± 1.8% | 75.50 ± 0.8% | -
DN4 [14] | ResNet12 | 54.37 ± 0.36% | 74.44 ± 0.29% | -
TADAM [19] | ResNet12 | 58.50% | 76.70% | 80.80%
Ours_Res | ResNet12 | 63.37 ± 0.14% | 79.03 ± 0.16% | 83.30 ± 0.13%

TABLE II: Few-shot classification accuracy results on MiniImageNet on 1-shot 5-way, 5-shot 5-way and 10-shot 5-way tasks. All accuracy results are reported with 95% confidence intervals. For each task, the best-performing method is highlighted. '-': the results are not reported. '+': re-implementation results for a fair comparison.

 

Baselines | Embedding Net | 1-Shot 5-Way | 5-Shot 5-Way
MatchingNet [29] | 4 Conv | 61.16 ± 0.89% | 72.86 ± 0.70%
MAML [6] | 4 Conv | 55.92 ± 0.95% | 72.09 ± 0.76%
ProtoNet [24] | 4 Conv | 51.31 ± 0.91% | 70.77 ± 0.69%
MACO [12] | 4 Conv | 60.76% | 74.96%
RelationNet [27] | 4 Conv | 62.45 ± 0.98% | 76.11 ± 0.69%
ProtoNet+ [24] | 4 Conv | 63.52 ± 0.25% | 79.06 ± 0.20%
Baseline++ [3] | 4 Conv | 60.53 ± 0.83% | 79.34 ± 0.61%
DN4-DA [14] | 4 Conv | 53.15 ± 0.84% | 81.90 ± 0.60%
Ours_conv | 4 Conv | 67.85 ± 0.24% | 83.53 ± 0.16%

TABLE III: Few-shot classification accuracy results on the CUB dataset [30] on 1-shot 5-way and 5-shot 5-way tasks. All accuracy results are reported with 95% confidence intervals. For each task, the best-performing method is highlighted. '+': re-implementation results for a fair comparison.

IV-C Training details

Several recent works show that a typical training process can include a pre-trained network [20, 22] or employ co-training [19] for feature embedding, which can significantly improve classification accuracy. In this paper, we adapt the training strategy from [20] to pre-train the feature embedding network discussed in Section IV-A. In our training process, we follow the standard setup applied by most few-shot learning frameworks, introduced in Section III-A, to train the embedding network, class encoder and decoder along with the metric module. Stochastic gradient descent (SGD) is used as the optimizer for the ResNet embedding network, while Adam is chosen as the optimizer for the Conv embedding network. During testing, we follow the setup in [24]: 15 query images per class are batched in each testing episode. The accuracy is obtained by averaging over 600 test tasks, which generate test batches with random classes and random samples. Table I shows the training parameters for the different embedding networks on the two datasets. The training time of the proposed method is close to that of ProtoNet, as the class regularization module has a simple architecture (about 75 seconds per epoch on a single P100 machine for both methods). The number of semantic bases is an important hyper-parameter; see the supplemental for its relation to classification accuracy.
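The optimizer choice described above can be summarized with the short sketch below; the learning rate and momentum are placeholders of ours, since the exact values belong in Table I and are not recoverable here.

import torch

def build_optimizer(model, embedding='resnet12', lr=1e-3):
    # Sec. IV-C: SGD for the ResNet12 embedding, Adam for the 4 Conv embedding.
    if embedding == 'resnet12':
        return torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)  # momentum assumed
    return torch.optim.Adam(model.parameters(), lr=lr)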

IV-D Comparison with other baselines

Fig. 5: t-SNE plot of latent-space features for 20 classes in the MiniImageNet test set and 10 classes in the CUB test set. 50 samples are randomly selected from each class. (A) ProtoNet (left) and ours (right) on the CUB dataset; (B) same as (A) but on the MiniImageNet dataset. Colors of the feature points represent class labels as in the right box.
Fig. 6: FID score [11]: The curves show the results of 500 tests. For each test, 25 samples (5-way 5-shot) for the training set and 75 samples (15 queries/class) for the query set are randomly chosen, and their sample features are obtained with both methods. The FID score measures the similarity of the training and testing feature distributions for the baselines.

As detailed in Table II, the proposed method is compared with baselines including both classic and more recent methods. It can be seen that our method achieves state-of-the-art results in all three tasks with different embedding networks on the MiniImageNet dataset. For the 1-shot 5-way test, based on the ResNet12 embedding network, our approach achieves 6.87% and 2.17% improvement over ProtoNet and MTL [26], respectively. The former is an amended variant of ProtoNet [24] using a pre-trained network as the embedding network for a fair comparison; the latter is a more recent work. For the 5-shot 5-way test, we observe a similar accuracy improvement of 4.83% over ProtoNet. Note that our method also yields a margin of 1.44% over the current state-of-the-art approach LEO [22], which applies a 28-layer wide residual network. Furthermore, we see that the performance of the proposed method improves as more images (shots) are received as input: our accuracy is 79.03% in the 5-shot 5-way test against 63.37% for 1-shot 5-way. This is because more samples per class provide more information to our network, further improving the descriptiveness of the learned aggregated class descriptor. Moreover, as shown in Table II, most of the existing methods apply a 12-layer residual network [10] or a 4 Conv blocks network for feature embedding. Some even apply a deeper feature embedding network with many more parameters to train, such as the ResNet-50 used by Self-Jig [34], the ResNet-34 used by Discriminative k-shot [1], and the 28-layer wide residual network used by both LEO [22] and Qiao-WRN [20]. Even so, our method still achieves better performance with a simpler embedding network with about 30% fewer parameters compared to the state-of-the-art method LEO.

Table III summarizes the comparisons on the CUB dataset. Our method yields the highest accuracy in all trials. In the 1-shot 5-way test, we achieve 67.85% accuracy, improving over the classic Matching Networks [29] by 6.69%. More significant improvement is achieved in the 5-shot 5-way test: our classification accuracy is 83.53%, an improvement of 4.47% over ProtoNet. Compared to the recently proposed Baseline++ [3], our method demonstrates significant improvements of 7.32% and 4.19% in the two tests.

Note that different baseline methods take input images of different sizes. Although most methods set the image size to 84×84, larger images can help improve classification accuracy under better data-augmentation conditions. For instance, Self-Jig [34] relies on larger input images with data augmentation, while our method only requires 84×84 images without data augmentation.

IV-E Effectiveness of Class Regularization

Meta Shift: To highlight the advantage of the proposed method against a non-class-regularization baseline (i.e., ProtoNet [24]), we measure the stability of our class descriptors against those of methods suffering from the meta shift issue. Note that for a fair comparison, we use the adapted ProtoNet with the same pre-trained embedding network. For each class in the test set of MiniImageNet, we randomly pick 5 samples to construct the class descriptor. We repeat this process for 500 tests to generate 500 class descriptors for each baseline. The Euclidean distance from the 'mean' class descriptor to the class descriptor in each test indicates the variation of the generated class descriptors. Fig. 2 shows the statistics for the class 'worm'. The proposed method (cyan curve, with much less variation) is able to generate more stable class descriptors than ProtoNet (magenta curve, with severe feature shifting). This results in better performance on the few-shot classification task. More meta shift comparisons are presented in the supplementary material.
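The measurement just described can be reproduced with a few lines of NumPy; descriptor_fn stands for either method's descriptor construction (average pooling for ProtoNet, encode-decode for ours) and is an assumption of this sketch.

import numpy as np

def meta_shift_distances(descriptor_fn, class_samples, n_tests=500, k_shot=5, seed=0):
    # class_samples: (n, ...) array of one class's samples.
    rng = np.random.default_rng(seed)
    descs = np.stack([
        descriptor_fn(class_samples[rng.choice(len(class_samples), k_shot, replace=False)])
        for _ in range(n_tests)])                  # (n_tests, D) descriptors
    # Euclidean distance of each descriptor to the mean descriptor (Fig. 2).
    return np.linalg.norm(descs - descs.mean(axis=0), axis=1)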

Covariate Shift: Given limited training data (even one sample per class for the 1-shot task), when dealing with unseen classes during meta testing, the training data distribution is likely to differ from the distribution of the query data (covariate shift [25]). This requires a generalized embedding network that embeds both into a similar feature distribution so that queries can be classified with a stable class descriptor from training.

Similar to the qualitative evaluation in [22], we use t-SNE projection to visualize the feature spaces of ProtoNet and our method. For clarity, we randomly choose 50 samples from each class in the test set and extract their features from the embedding network. The t-SNE plot of the feature space is shown in Fig. 5 (the plot of all samples is included in the supplementary material). The feature space of the proposed method is observed to be different from that of ProtoNet on both the MiniImageNet and CUB datasets. Our embedding features are densely clustered within each class while sparsely distributed among different classes, demonstrating intra-class commonality and inter-class distinctiveness. As training and testing sets are randomly selected from these samples, this qualitative result shows that the proposed class regularization effectively fine-tunes the embedding network and better deals with the covariate shift issue.

We also compare the similarity of training sample features against query sample features. This quantitatively measures the efficiency of the embedding networks fine-tuned by each baseline. In this work, we apply the Fréchet Inception Distance (FID) score [11], which is widely used in generative model evaluation, to measure the feature distribution difference between the training samples and testing samples. Ideally, with a close distribution, the classifier can perform better. As shown in Fig. 6, the embedding network fine-tuned by the proposed class regularization method yields consistently lower FID scores than the one fine-tuned by ProtoNet, indicating higher similarity and a better embedding network.
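For reference, the FID between two feature sets follows the closed form of [11]; a plain NumPy/SciPy version is sketched below.

import numpy as np
from scipy.linalg import sqrtm

def fid(feats_a, feats_b):
    # feats_a, feats_b: (n, D) feature matrices from the two distributions.
    mu_a, mu_b = feats_a.mean(0), feats_b.mean(0)
    cov_a = np.cov(feats_a, rowvar=False)
    cov_b = np.cov(feats_b, rowvar=False)
    covmean = sqrtm(cov_a @ cov_b)                # matrix square root of the product
    if np.iscomplexobj(covmean):                  # drop tiny imaginary parts from numerics
        covmean = covmean.real
    # ||mu_a - mu_b||^2 + Tr(cov_a + cov_b - 2 (cov_a cov_b)^(1/2))
    return float(((mu_a - mu_b) ** 2).sum() + np.trace(cov_a + cov_b - 2 * covmean))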

The qualitative and quantitative results show that the proposed method can not only tackle the meta shift issue but also efficiently fine-tune the embedding network to alleviate the covariate shift issue in few-shot learning. These two benefits make the proposed method significantly outperform other baselines.

Metric module: In addition, we perform an ablation study by simply applying Euclidean distance for image classification (same as ProtoNet). The results are shown in Table IV. For clarity, we also show the performance of ProtoNet with the same embedding network. Due to the effectiveness of the proposed class regularization method, the network without the relation module still outperforms the baselines listed in Table II and Table III. The relation module combined with Euclidean distance further improves performance.

 

Dataset | Framework | 1-shot | 5-shot
MiniImageNet | Res+Euc | 56.50 ± 0.40% | 74.20 ± 0.20%
MiniImageNet | Res+ACD+Euc | 62.97 ± 0.16% | 78.87 ± 0.15%
MiniImageNet | Res+ACD+(Euc+R) | 63.37 ± 0.14% | 79.03 ± 0.16%
CUB | Conv+Euc | 63.52 ± 0.25% | 79.06 ± 0.20%
CUB | Conv+ACD+Euc | 66.80 ± 0.22% | 82.00 ± 0.15%
CUB | Conv+ACD+(Euc+R) | 67.85 ± 0.24% | 83.53 ± 0.16%

TABLE IV: Ablation study on the two datasets comparing ProtoNet (top), our method without the relation module (middle), and with the relation module (bottom).

V Conclusion and Future Work

In this paper, to tackle the meta shift problem, we propose a novel meta-learning framework that employs class regularization for better feature embedding and better class descriptor over the entire training set. We utilize an aggregation of image embeddings for better extraction of finer details during the training process, along with an effective representation of the target classes using aggregated descriptors learned from the framework. We evaluate the proposed framework by comparing with existing few-shot learning methods on the MiniImageNet and CUB datasets. The state-of-the-art performance demonstrates the advantage of our framework over previous works.

Several directions can be explored in the future. One way to improve the effectiveness of the class regularization module would be to involve task condition/embedding along with the current image embedding. Another direction is to further strengthen the embedding network by using FPN [15] or attention based techniques [31].

References

  • [1] M. Bauer, M. Rojas-Carulla, J. B. Świątkowski, B. Schölkopf, and R. E. Turner (2017) Discriminative k-shot learning using probabilistic models. arXiv preprint arXiv:1706.00326.
  • [2] G. Koch, R. Zemel, and R. Salakhutdinov (2015) Siamese neural networks for one-shot image recognition. In ICML Deep Learning Workshop, Vol. 2.
  • [3] W. Chen, Y. Liu, Z. Kira, Y. F. Wang, and J. Huang (2019) A closer look at few-shot classification. In International Conference on Learning Representations.
  • [4] W. Chu, Y. Li, J. Chang, and Y. F. Wang (2019) Spot and learn: a maximum-entropy patch sampler for few-shot image classification. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 6251–6260.
  • [5] M. Douze, A. Szlam, B. Hariharan, and H. Jégou (2018) Low-shot learning with large-scale diffusion. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 3349–3358.
  • [6] C. Finn, P. Abbeel, and S. Levine (2017) Model-agnostic meta-learning for fast adaptation of deep networks. In Proceedings of the 34th International Conference on Machine Learning, pp. 1126–1135.
  • [7] V. Garcia and J. Bruna (2017) Few-shot learning with graph neural networks. arXiv preprint arXiv:1711.04043.
  • [8] Y. Gong, L. Wang, R. Guo, and S. Lazebnik (2014) Multi-scale orderless pooling of deep convolutional activation features. In European Conference on Computer Vision, pp. 392–407.
  • [9] B. Hariharan and R. Girshick (2017) Low-shot visual recognition by shrinking and hallucinating features. In Proceedings of the IEEE International Conference on Computer Vision, pp. 3018–3027.
  • [10] K. He, X. Zhang, S. Ren, and J. Sun (2016) Deep residual learning for image recognition. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 770–778.
  • [11] M. Heusel, H. Ramsauer, T. Unterthiner, B. Nessler, and S. Hochreiter (2017) GANs trained by a two time-scale update rule converge to a local Nash equilibrium. In Advances in Neural Information Processing Systems, pp. 6626–6637.
  • [12] N. Hilliard, L. Phillips, S. Howland, A. Yankov, C. D. Corley, and N. O. Hodas (2018) Few-shot learning with metric-agnostic conditional embeddings. arXiv preprint arXiv:1802.04376.
  • [13] X. Jiang, M. Havaei, F. Varno, G. Chartrand, N. Chapados, and S. Matwin (2019) Learning to learn with conditional class dependencies. In International Conference on Learning Representations.
  • [14] W. Li, L. Wang, J. Xu, J. Huo, Y. Gao, and J. Luo (2019) Revisiting local descriptor based image-to-class measure for few-shot learning. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 7260–7268.
  • [15] T. Lin, P. Dollár, R. B. Girshick, K. He, B. Hariharan, and S. J. Belongie (2017) Feature pyramid networks for object detection. In CVPR.
  • [16] Y. Liu, J. Lee, M. Park, S. Kim, E. Yang, S. Hwang, and Y. Yang (2019) Learning to propagate labels: transductive propagation network for few-shot learning. In International Conference on Learning Representations.
  • [17] N. Mishra, M. Rohaninejad, X. Chen, and P. Abbeel (2017) A simple neural attentive meta-learner. arXiv preprint arXiv:1707.03141.
  • [18] A. Nichol and J. Schulman (2018) Reptile: a scalable metalearning algorithm. arXiv preprint arXiv:1803.02999.
  • [19] B. Oreshkin, P. Rodríguez López, and A. Lacoste (2018) TADAM: task dependent adaptive metric for improved few-shot learning. In Advances in Neural Information Processing Systems, pp. 719–729.
  • [20] S. Qiao, C. Liu, W. Shen, and A. L. Yuille (2018) Few-shot image recognition by predicting parameters from activations. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 7229–7238.
  • [21] S. Ravi and H. Larochelle (2017) Optimization as a model for few-shot learning. In International Conference on Learning Representations.
  • [22] A. A. Rusu, D. Rao, J. Sygnowski, O. Vinyals, R. Pascanu, S. Osindero, and R. Hadsell (2019) Meta-learning with latent embedding optimization. In International Conference on Learning Representations.
  • [23] K. Simonyan and A. Zisserman (2014) Very deep convolutional networks for large-scale image recognition. arXiv preprint arXiv:1409.1556.
  • [24] J. Snell, K. Swersky, and R. Zemel (2017) Prototypical networks for few-shot learning. In Advances in Neural Information Processing Systems, pp. 4077–4087.
  • [25] M. Sugiyama, M. Krauledat, and K. Müller (2007) Covariate shift adaptation by importance weighted cross validation. Journal of Machine Learning Research 8 (May), pp. 985–1005.
  • [26] Q. Sun, Y. Liu, T. Chua, and B. Schiele (2019) Meta-transfer learning for few-shot learning. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 403–412.
  • [27] F. Sung, Y. Yang, L. Zhang, T. Xiang, P. H. Torr, and T. M. Hospedales (2018) Learning to compare: relation network for few-shot learning. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 1199–1208.
  • [28] C. Szegedy, S. Ioffe, V. Vanhoucke, and A. A. Alemi (2017) Inception-v4, Inception-ResNet and the impact of residual connections on learning. In AAAI.
  • [29] O. Vinyals, C. Blundell, T. Lillicrap, D. Wierstra, et al. (2016) Matching networks for one shot learning. In Advances in Neural Information Processing Systems, pp. 3630–3638.
  • [30] C. Wah, S. Branson, P. Welinder, P. Perona, and S. Belongie (2011) The Caltech-UCSD Birds-200-2011 dataset. Technical Report CNS-TR-2011-001, California Institute of Technology.
  • [31] F. Wang, M. Jiang, C. Qian, S. Yang, C. Li, H. Zhang, X. Wang, and X. Tang (2017) Residual attention network for image classification. arXiv preprint arXiv:1704.06904.
  • [32] Y. Wang, R. Girshick, M. Hebert, and B. Hariharan (2018) Low-shot learning from imaginary data. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 7278–7286.
  • [33] B. Zhao, X. Sun, Y. Fu, Y. Yao, and Y. Wang (2018) MSplit LBI: realizing feature selection and dense estimation simultaneously in few-shot and zero-shot learning. arXiv preprint arXiv:1806.04360.
  • [34] Z. Chen, Y. Fu, K. Chen, and Y. Jiang (2019) Image block augmentation for one-shot learning. In AAAI.

VI Introduction

In this supplementary document, we present 1) additional results to show the effectiveness of the proposed class regularization network; 2) mathematical details of the class encoder-decoder module and metric module; 3) the algorithmic procedure of the proposed method.

VII Effectiveness of Class Regularization on Meta Shift Reduction: More Qualitative Results and Meta Shift Plots

As mentioned in the main paper, we also plot the t-SNE visualization of the feature space for ProtoNet and the proposed method. Fig. 7 includes all the samples of the test sets for both MiniImageNet and CUB datasets.

As an important component of the proposed framework, the number of semantic bases in the class domain is essential to the performance of the proposed method. Fig. 8 shows the effect of the number of semantic bases on the accuracy of the proposed method.

Fig. 9 shows additional meta shift plots of more classes similar to the plot in Fig. 2 in the main paper. Note that all classes are from the test set of MiniImageNet (‘worm’, ‘crab’, ‘golden retriever’, ‘malamute’, ‘dalmatian’, ‘hyena dog’, ‘lion’, ‘ant’, ‘ferret’, ‘bookshop’).

Fig. 9: Additional meta shift plots on ten different classes in the test set of MiniImageNet: statistics on the distance between class descriptors and their mean over 500 tests for the baseline method (ProtoNet) and the proposed method. Class names for the descriptors: first row: 'worm', 'crab'; second row: 'golden retriever', 'malamute'; third row: 'dalmatian', 'hyena dog'; fourth row: 'lion', 'ant'; fifth row: 'ferret', 'bookshop'.
1:  Initialize the proposed network
2:  for each episode e in {1, ..., E} do                    ▷ meta training episodes
3:      for each task t in {1, ..., T} do                   ▷ tasks in one episode
4:          V = RandomSample({1, ..., C}, N)                ▷ randomly select N classes
5:          S_c = RandomSample(D_c, K) for each c in V      ▷ sample support sets
6:          Q_c = RandomSample(D_c \ S_c, J) for each c in V  ▷ sample sub-dataset for the task
7:          for each class c in V do
8:              F_c = f_φ(S_c)                              ▷ embed support samples
9:              R_c = E(F_c)                                ▷ encode class representation
10:             P_c = T(R_c)                                ▷ representation of each class
11:         end for
12:         L = 0
13:         for each class c in V do
14:             for each query image x̂ in Q_c do
15:                 L = L + λ1·L_d + λ2·L_r                 ▷ accumulate class re-construction loss
16:             end for
17:         end for
18:         Back propagation on L
19:     end for
20: end for
Algorithm 1: Algorithmic process for the Class Regularization Network. D is the entire training set; C is the total number of classes in D; E is the number of episodes for meta training; T is the number of tasks in one meta episode; V is the set of classes selected for the support set; N is the number of classes sampled for each task in one meta episode; D_c is the image set of the c-th class in D; S_c is the sampled support image set of the c-th class, where each image is labelled; Q_c is the query image set used to test the re-construction of the c-th class, where each image is unlabelled; K is the size of each support image set; J is the size of each test/query image set.

VIII Detailed math of class domain construction and metric module

VIII-A Class domain construction

In this section, we include the background mathematical details of the proposed method. The notation is consistent with Section III of the main paper.

The embedding function of the proposed method provides a feature embedding of the given samples:

$F_k = f_\varphi(S_k)$   (10)

where $F_k \in \mathbb{R}^{K \times D}$ is a tensor that represents the embedding features of all samples from the $k$-th class in the support set for the current training task, and $D$ is the size of the embedding feature for each input sample. The class representation in the class domain can be written as:

$R_k = E(F_k)$   (11)

where $E$ is the class encoder function. We compose the class representation with the semantic bases, which can be treated as a weighted summation of residuals:

$R(j, m) = \sum_{i=1}^{K} a_i^m \left( e_i(j) - b_m(j) \right)$   (12)

where $e_i(j)$ is the $j$-th dimension of the $i$-th embedding feature, and $b_m(j)$ is the $j$-th dimension of the representation of the $m$-th semantic basis. $a_i^m = 1$ if the representation of semantic basis $b_m$ is the closest to embedding feature $e_i$, and $a_i^m = 0$ otherwise.

$a_i^m = \frac{e^{-\alpha \|e_i - b_m\|^2}}{\sum_{m'=1}^{M} e^{-\alpha \|e_i - b_{m'}\|^2}}$   (13)

where $\alpha$ is a parameter that controls the sensitivity to the residual.

In this supplementary material, we expand Eq. 13 and obtain the final representation of Eq. 16. Eq. 13 can be expanded as:

$a_i^m = \frac{e^{-\alpha \left( \|e_i\|^2 - 2 e_i^T b_m + \|b_m\|^2 \right)}}{\sum_{m'} e^{-\alpha \left( \|e_i\|^2 - 2 e_i^T b_{m'} + \|b_{m'}\|^2 \right)}}$   (14)

As the term $e^{-\alpha \|e_i\|^2}$ cancels between the numerator and the denominator, and the remaining summation in the denominator is only related to the representation of the semantic bases, Eq. 14 can be transformed as:

$a_i^m = \frac{e^{w_m^T e_i + c_m}}{\sum_{m'} e^{w_{m'}^T e_i + c_{m'}}}, \quad w_m = 2\alpha b_m, \; c_m = -\alpha \|b_m\|^2$   (15)

It can be seen that the embedding feature is assigned to the semantic bases using Eq. 12 in a weighted manner, where the weight is given by Eq. 13. Eq. 12 can thus be re-written with the soft assignments as:

$R(j, m) = \sum_{i=1}^{K} a_i^m \left( e_i(j) - b_m(j) \right)$   (16)

Based on Eq. 12 and Eq. 15, Eq. 16 can be rewritten as:

$R(j, m) = \sum_{i=1}^{K} \frac{e^{w_m^T e_i + c_m}}{\sum_{m'} e^{w_{m'}^T e_i + c_{m'}}} \left( e_i(j) - b_m(j) \right)$   (17)

Based on Eq. 17, given the embedding features of the samples from a class, we can obtain the class representation over the semantic bases in the class domain. The parameters of the encoder function can be represented as:

$\Theta_E = \{ w_m, c_m, b_m \}_{m=1}^{M}$   (18)

The obtained set is the trainable parameter set of the encoder for the semantic bases $b_m$.

VIII-B Metric Module

Similar to other metric-learning based methods [24], we employ a Euclidean distance function $d(\cdot,\cdot)$ and produce a distribution over all classes given a sample $\hat{x}$ from the query set $Q$:

$p(y = k \mid \hat{x}) = \frac{\exp(-d(f_\varphi(\hat{x}), P_k))}{\sum_{k'} \exp(-d(f_\varphi(\hat{x}), P_{k'}))}$   (19)

As shown in Eq. 19, the distribution is based on a softmax over the distances between the embedding of the samples (in the query set) and the reconstructed features of the classes. The loss function of our network can then read:

$L_d = -\log p(y = k \mid \hat{x})$   (20)

As mentioned in the main paper, a relation module is included in the metric module. The relation score is a matrix whose elements can be represented as:

$r_{k,j} = g\left( \mathcal{C}(f_\varphi(\hat{x}_j), P_k) \right)$   (21)

where $g$ is the relation module with two convolution layers and two fully connected layers, and $\mathcal{C}$ is a function that combines the two feature tensors; in this paper, we simply concatenate them as the input of the relation module. In [27], mean square error (MSE) is recommended to train the relation module:

$L_{MSE} = \sum_{j} \sum_{k} \left( r_{k,j} - \mathbf{1}(y_j = k) \right)^2$   (22)

where $y_j$ is the correct label for query image $\hat{x}_j$, and $\mathbf{1}(y_j = k)$ is an $N$-dimensional one-hot vector with element 1 when $y_j = k$ and 0 otherwise. Note that for each query image $\hat{x}_j$, the relation score $r_j$ has the same dimension $N$ as the number of classes, and can be treated as a probability distribution over the $N$ classes. Similar to Eq. 19, given the correct label $k$ for $\hat{x}_j$, the distribution can be written as:

$p_r(y = k \mid \hat{x}_j) = \frac{\exp(r_{k,j})}{\sum_{k'} \exp(r_{k',j})}$   (23)

Hence, the loss for the relation module can also be written as:

$L_r = -\log p_r(y = k \mid \hat{x}_j)$   (24)

The total loss of the network can then be summarized as:

$L = \lambda_1 L_d + \lambda_2 L_r + \lambda_3 \|\Theta\|_2^2$   (25)

where $\lambda_1$, $\lambda_2$, $\lambda_3$ are the weights of the two losses and the regularization term, and $\Theta$ includes the training parameters of the class encoder, class decoder and metric module.

IX Algorithmic procedure of class regularization network

The process of our proposed network is summarized in Algorithm 1.