Few-shot learning via tensor hallucination

04/19/2021 ∙ by Michalis Lazarou, et al.

Few-shot classification addresses the challenge of classifying examples given only limited labeled data. A powerful approach is to go beyond data augmentation, towards data synthesis. However, most data augmentation/synthesis methods for few-shot classification are overly complex and sophisticated, e.g. training a wGAN with multiple regularizers or training a network to transfer latent diversities from known to novel classes. We make two contributions, namely we show that: (1) using a simple loss function is more than enough for training a feature generator in the few-shot setting; and (2) learning to generate tensor features instead of vector features is superior. Extensive experiments on the miniImageNet, CUB and CIFAR-FS datasets show that our method sets a new state of the art, outperforming more sophisticated few-shot data augmentation methods.

1 Introduction

Deep learning continues to improve the state of the art in many different fields, such as natural language understanding Mikolov et al. (2010) and computer vision Krizhevsky et al. (2012). However, even though the success of deep learning models is undeniable, a fundamental limitation is their dependence on large amounts of labeled data. This limitation inhibits the application of state-of-the-art deep learning methods to real-world problems where annotating data is costly and data can be scarce, e.g. rare species classification.

To address this limitation, few-shot learning has attracted significant interest in recent years. One of the most common lines of research is meta-learning, where training episodes mimic a few-shot task by having a small number of classes and a limited number of examples per class. Meta-learning approaches can further be partitioned into three groups. Optimization-based approaches learn to update the learner's meta-parameters Finn et al. (2017); Zintgraf et al.; Nichol et al. (2018); Ravi & Larochelle (2016). Metric-based approaches learn a discriminative embedding space where novel examples are easy to classify Koch et al. (2015); Snell et al. (2017); Sung et al. (2018); Vinyals et al. (2016). Model-based approaches depend on specific model architectures to learn how to update the learner's parameters effectively Santoro et al. (2016); Munkhdalai & Yu (2017).

Beyond meta-learning, other approaches include leveraging the manifold structure of the data, by label propagation, embedding propagation or graph neural networks Garcia & Bruna (2017); Kim et al. (2019); Liu et al. (2018); Lazarou et al. (2020); and domain adaptation, reducing the domain shift between source and target domains Dong & Xing (2018); Hsu et al. (2017). Another line of research is data augmentation, which addresses data deficiency by augmenting the few-shot training dataset with extra examples in the image space Chen et al. (2019); Zhang et al. (2018) and in the feature space Chen et al. (2019); Li et al. (2020); Luo et al. (2021). Such methods go beyond standard augmentation Krizhevsky et al. (2012) towards synthetic data generation and hallucination, achieving greater diversity.

Our work falls into the category of data augmentation in the feature space. We show that training a feature hallucinator with a simple loss function can outperform other state-of-the-art few-shot data augmentation methods that use more complex and sophisticated generation methods, such as a wGAN Li et al. (2020), a VAE Luo et al. (2021) and networks trained to transfer example diversity Chen et al. (2020). Moreover, to the best of our knowledge, we are the first to propose generating tensor features instead of vector features in the few-shot setting.

2 Method

2.1 Problem formulation

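We are given a labeled dataset $D_{\text{base}} := \{(x_i, y_i)\}$, with each example $x_i$ having a label $y_i$ in one of the classes in $C_{\text{base}}$. This dataset is used to learn a parametrized mapping $f_\theta: \mathcal{X} \to \mathbb{R}^{d \times H \times W}$ from an input image space $\mathcal{X}$ to a feature or embedding space, where feature tensors have $d$ dimensions (channels) and spatial resolution $H \times W$ (height $\times$ width).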

The knowledge acquired at representation learning is used to solve novel tasks, assuming access to a dataset $D_{\text{novel}}$, with each example being associated with one of the classes in $C_{\text{novel}}$, where $C_{\text{novel}}$ is disjoint from $C_{\text{base}}$. In few-shot classification Vinyals et al. (2016), a novel task is defined by sampling a support set $S$ from $D_{\text{novel}}$, consisting of $N$ classes with $K$ labeled examples per class, for a total of $L := NK$ examples. Given the mapping $f_\theta$ and the support set $S$, the problem is to learn an $N$-way classifier that makes predictions on unlabeled queries, also sampled from $D_{\text{novel}}$. Queries are treated independently of each other; this is referred to as inductive inference.
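A minimal sketch of how one such task can be sampled, assuming the dataset is given as a plain list of integer labels. The function name `sample_episode` and the per-class query count `n_query` are illustrative choices, not taken from the paper:

```python
import random
from collections import defaultdict


def sample_episode(labels, n_way, k_shot, n_query, rng=random):
    """Sample one N-way, K-shot task from a labeled dataset.

    labels[i] is the class of example i. Returns (support, query), two lists of
    (example_index, episode_label) pairs, with episode labels in 0..n_way-1.
    """
    by_class = defaultdict(list)
    for idx, y in enumerate(labels):
        by_class[y].append(idx)
    # Only classes with enough examples for both the support and query sets qualify.
    eligible = [c for c, idxs in by_class.items() if len(idxs) >= k_shot + n_query]
    classes = rng.sample(eligible, n_way)
    support, query = [], []
    for j, c in enumerate(classes):
        # Draw K + n_query distinct examples: the first K form the support set.
        picked = rng.sample(by_class[c], k_shot + n_query)
        support += [(i, j) for i in picked[:k_shot]]
        query += [(i, j) for i in picked[k_shot:]]
    return support, query


# Toy dataset: 10 classes with 20 examples each.
labels = [i // 20 for i in range(200)]
S, Q = sample_episode(labels, n_way=5, k_shot=1, n_query=15, rng=random.Random(0))
print(len(S), len(Q))  # 5 75
```

Support and query examples are drawn together without replacement, so the two sets never overlap within a task.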

2.2 Representation learning

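The goal of representation learning is to learn the embedding function $f_\theta$ that can be applied to $D_{\text{novel}}$ to extract embeddings and solve novel tasks. We use $f_\theta$ followed by global average pooling (GAP) and a parametric base classifier $c_\phi$ to learn the representation. We denote by $f'_\theta$ the composition of $f_\theta$ and GAP. We follow the two-stage regime of Tian et al. (2020) to train our embedding model. In the first stage, we train $f'_\theta$ on $D_{\text{base}}$ using the standard cross-entropy loss $\ell$:

$$\theta^*, \phi^* = \arg\min_{\theta, \phi} \sum_{(x_i, y_i) \in D_{\text{base}}} \ell\big(c_\phi(f'_\theta(x_i)), y_i\big) + R(\theta, \phi) \qquad (1)$$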

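where $R$ is a regularization term. In the second stage, we adopt a self-distillation process. The embedding model $f'_{\theta^*}$ and classifier $c_{\phi^*}$ from the first stage serve as the teacher, and we distill their knowledge to a new student model $f'_{\theta'}$ and classifier $c_{\phi'}$ with identical architecture. The student is trained using a linear combination of the standard cross-entropy loss, as in stage one, and the Kullback-Leibler (KL) divergence between the student and teacher predictions:

$$\theta', \phi' = \arg\min_{\theta, \phi} \sum_{(x_i, y_i) \in D_{\text{base}}} \Big[ \alpha\, \ell\big(c_\phi(f'_\theta(x_i)), y_i\big) + \beta\, \mathrm{KL}\big(c_\phi(f'_\theta(x_i)),\, c_{\phi^*}(f'_{\theta^*}(x_i))\big) \Big] \qquad (2)$$

where $\alpha$ and $\beta$ are scalar weights and the teacher parameters $\theta^*, \phi^*$ are fixed.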
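The self-distillation loss in (2) can be sketched in NumPy as below. This is an illustration under stated assumptions, not the authors' code:

- classifier outputs are treated as logits passed through a softmax;
- the KL term uses the common distillation direction KL(teacher ‖ student), with no temperature;
- losses are averaged rather than summed over examples.

Setting `beta=0` leaves only the cross-entropy term of (1), without the regularizer $R$.

```python
import numpy as np


def log_softmax(logits):
    # Numerically stable log-softmax over the class axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def distillation_loss(student_logits, teacher_logits, labels, alpha=1.0, beta=1.0):
    """alpha * cross-entropy + beta * KL(teacher || student), averaged over the batch.

    student_logits, teacher_logits: (B, C) arrays; labels: (B,) integer class indices.
    """
    log_p_s = log_softmax(student_logits)
    log_p_t = log_softmax(teacher_logits)
    # Cross-entropy of the student with respect to the ground-truth labels.
    ce = -log_p_s[np.arange(len(labels)), labels].mean()
    # KL divergence from the teacher distribution to the student distribution.
    kl = (np.exp(log_p_t) * (log_p_t - log_p_s)).sum(axis=-1).mean()
    return alpha * ce + beta * kl


rng = np.random.default_rng(0)
student, teacher = rng.normal(size=(4, 3)), rng.normal(size=(4, 3))
y = np.array([0, 1, 2, 0])
print(distillation_loss(student, teacher, y, alpha=0.5, beta=0.5))
```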

2.3 Feature tensor hallucinator

All existing feature hallucination methods are trained using vector features, losing significant spatial and structural information. By contrast, our hallucinator is trained on the tensor features before global average pooling and generates tensor features as well. In particular, we use the student model $f_{\theta'}$ (without GAP), pre-trained using (2), as our embedding network to train our tensor feature hallucinator. The hallucinator consists of two networks: a conditioner network $h$ and a generator network $g$. The conditioner aids the generator in generating class-conditional examples. Given a set $\{x_i^j\}_{i=1}^{K}$ of examples associated with class $j$ for $j = 1, \dots, N$, conditioning is based on the prototype tensor $p_j$ of each class $j$,

$$p_j := \frac{1}{K} \sum_{i=1}^{K} f_{\theta'}(x_i^j) \in \mathbb{R}^{d \times H \times W}. \qquad (3)$$

The conditioner maps the prototype tensor $p_j$ to the class-conditional vector $s_j := h(p_j) \in \mathbb{R}^{k}$. The generator takes as input this vector as well as a $k$-dimensional sample $z$ from a standard normal distribution $\mathcal{N}(0, I_k)$. It then generates a class-conditional tensor feature $g(s_j, z) \in \mathbb{R}^{d \times H \times W}$ for class $j$.
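To make the data flow concrete, the following sketch computes a prototype tensor as the mean of a class's $K$ tensor features, as in (3). It then passes that prototype through stand-ins for the conditioner and generator.

The sizes `d, H, W, k, K` are toy values. The two random linear maps are hypothetical placeholders, not the convolutional networks described in Section 3.1; only the input and output shapes follow the description above.

```python
import numpy as np

d, H, W, k, K = 8, 3, 3, 4, 5   # toy sizes, not the paper's values
rng = np.random.default_rng(0)


def prototype(features):
    """Eq. (3): average K tensor features of shape (K, d, H, W) into one (d, H, W) prototype."""
    return features.mean(axis=0)


# Hypothetical placeholders for the conditioner h and generator g: single random
# linear maps instead of the convolutional networks used in the paper.
W_h = rng.normal(size=(k, d * H * W)) / np.sqrt(d * H * W)
W_g = rng.normal(size=(d * H * W, 2 * k)) / np.sqrt(2 * k)


def conditioner(p):
    # Flatten the prototype tensor and map it to a k-dimensional vector s_j.
    return W_h @ p.reshape(-1)


def generator(s, z):
    # Concatenate s_j with the noise z, map back to d*H*W values,
    # squash with a sigmoid and reshape into a tensor feature.
    x = W_g @ np.concatenate([s, z])
    return (1.0 / (1.0 + np.exp(-x))).reshape(d, H, W)


feats = rng.random((K, d, H, W))          # stand-in for K support tensor features of class j
p = prototype(feats)
s = conditioner(p)
fake = generator(s, rng.standard_normal(k))
print(p.shape, s.shape, fake.shape)       # (8, 3, 3) (4,) (8, 3, 3)
```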

2.4 Training the hallucinator

We train our hallucinator using a meta-training regime, similarly to Li et al. (2020); Chen et al. (2019); Schwartz et al. (2018). At every iteration, we sample a new episode by randomly sampling $N$ classes and $K$ examples $x_i^j$ for each class $j$ from $D_{\text{base}}$. We obtain the prototype tensor $p_j$ for each class $j$ by (3) and the class-conditional vector $s_j = h(p_j)$ by the conditioner $h$. For each class $j$, we draw $M$ samples $z_m^j \sim \mathcal{N}(0, I_k)$, $m = 1, \dots, M$, and generate $M$ class-conditional tensor features $g(s_j, z_m^j)$ using the generator $g$. We train our hallucinator on the episode data by minimizing the mean squared error (MSE) of the generated class-conditional tensor features of class $j$ to the corresponding class prototype $p_j$:

$$\mathcal{L}(h, g) := \frac{1}{MN} \sum_{j=1}^{N} \sum_{m=1}^{M} \big\| g(s_j, z_m^j) - p_j \big\|^2 \qquad (4)$$
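A NumPy sketch of the MSE objective in (4), assuming the generated features of an episode are stored as an array of shape (N, M, d, H, W). The name `hallucinator_loss` is hypothetical.

A per-element mean, as computed by PyTorch's `torch.nn.MSELoss` with its default reduction, differs from this value only by the constant factor $dHW$.

```python
import numpy as np


def hallucinator_loss(generated, prototypes):
    """Sketch of Eq. (4).

    generated:  (N, M, d, H, W) array, M generated tensor features per class
    prototypes: (N, d, H, W) array, the class prototypes p_j of Eq. (3)
    """
    diff = generated - prototypes[:, None]        # broadcast p_j over its M samples
    # Squared Frobenius norm of each difference, averaged over the N * M generated tensors.
    return (diff ** 2).sum(axis=(2, 3, 4)).mean()


rng = np.random.default_rng(0)
protos = rng.random((5, 4, 3, 3))                 # N = 5 toy prototypes
fake = protos[:, None] + 0.1 * rng.standard_normal((5, 10, 4, 3, 3))  # M = 10 per class
print(hallucinator_loss(fake, protos))
```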

2.5 Inference

At inference, we are given a few-shot task with a support set $S$ of $N$ classes with $K$ examples $x_i^j$ for each class $j$. We compute the tensor feature $f_{\theta'}(x_i^j) \in \mathbb{R}^{d \times H \times W}$ of each example using our trained backbone network $f_{\theta'}$ and obtain the prototype $p_j$ of each class $j$ by (3). Then, using our trained tensor feature hallucinator, we generate $M$ class-conditional tensor features $g(h(p_j), z_m^j)$ for each class $j$, also in $\mathbb{R}^{d \times H \times W}$, where the $z_m^j$ are drawn from $\mathcal{N}(0, I_k)$. We augment the support features with the generated features, resulting in $K + M$ labeled tensor features per class in total. We then apply GAP to those tensor features and obtain new vector class prototypes in $\mathbb{R}^d$. Finally, we also apply GAP to the query tensor features and classify each query to the class of the nearest prototype.
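The inference step can be sketched as follows, assuming the support, generated and query tensor features are already available as NumPy arrays. Squared Euclidean distance as the meaning of "nearest prototype" is our assumption, and `gap` and `classify_queries` are hypothetical names.

Because GAP and averaging are both linear, pooling each tensor and then averaging gives the same prototype as averaging first.

```python
import numpy as np


def gap(t):
    """Global average pooling over the two trailing spatial axes: (..., d, H, W) -> (..., d)."""
    return t.mean(axis=(-2, -1))


def classify_queries(support, generated, queries):
    """Nearest-prototype classification with hallucinated support features.

    support:   (N, K, d, H, W) real support tensor features
    generated: (N, M, d, H, W) hallucinated tensor features
    queries:   (Q, d, H, W)    query tensor features
    Returns an array of Q predicted class indices in 0..N-1.
    """
    augmented = np.concatenate([support, generated], axis=1)  # K + M features per class
    protos = gap(augmented).mean(axis=1)                       # (N, d) vector prototypes
    q = gap(queries)                                           # (Q, d)
    # Squared Euclidean distance between every query and every prototype.
    dists = ((q[:, None, :] - protos[None, :, :]) ** 2).sum(axis=-1)
    return dists.argmin(axis=1)


# Toy check: every entry of class j's features is close to 10 * j.
N, K, M, d, H, W = 3, 1, 4, 2, 3, 3
support = np.stack([np.full((K, d, H, W), 10.0 * j) for j in range(N)])
generated = np.stack([np.full((M, d, H, W), 10.0 * j + 1.0) for j in range(N)])
queries = np.stack([np.full((d, H, W), v) for v in (21.0, 0.5, 11.0)])
print(classify_queries(support, generated, queries))  # [2 0 1]
```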

3 Experiments

3.1 Setup

Datasets

We carry out experiments on three commonly used few-shot classification benchmark datasets: miniImageNet, CUB and CIFAR-FS. Further details are provided in subsection A.1.

Tasks

We consider $N$-way, $K$-shot classification tasks with $K \in \{1, 5\}$. Each task uses $N$ randomly sampled novel classes and $K$ examples drawn at random per class as the support set $S$, that is, $L := NK$ examples in total. For the query set $Q$, we draw $q$ additional examples per class, that is, $Nq$ examples in total; this number of queries per class is the most common choice Liu et al. (2018); Li et al. (2019); Yu et al. (2020).

Competitors

We compare our method with state-of-the-art data augmentation methods for few-shot learning, including MetaGAN Zhang et al. (2018), Δ-encoder Schwartz et al. (2018), salient network (SalNet) Zhang et al. (2019), diversity transfer network (DTN) Chen et al. (2020), dual TriNet Chen et al. (2019), image deformation meta-network (IDeMe-Net) Chen et al. (2019), adversarial feature hallucination network (AFHN) Li et al. (2020) and variational inference network (VI-Net) Luo et al. (2021).

Networks

Many recent competitors Chen et al. (2019, 2019); Li et al. (2020); Luo et al. (2021) use ResNet-18 as the backbone embedding model. To make the comparison as fair as possible, we use the same backbone.

Our tensor feature hallucinator (TFH) consists of a conditioner network and a generator network. The conditioner consists of two convolutional layers with a ReLU activation in between, followed by flattening and a fully-connected layer. The generator first concatenates the noise sample $z$ and the class-conditional vector $s$ into a single vector and reshapes it into a tensor. It then applies three transpose-convolutional layers with ReLU activation functions in between and a sigmoid function at the end. More details are provided in subsection A.2.
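The helpers below compute how spatial resolution evolves through the convolutional conditioner and the transpose-convolutional generator. They implement the standard output-size formulas for convolutions and transposed convolutions with dilation 1, the same formulas documented for PyTorch's `Conv2d` and `ConvTranspose2d`.

The example makes three assumptions for illustration only:

- a kernel size of 3;
- a 7 × 7 input to the conditioner;
- a 1 × 1 spatial map after the generator's reshaping.

None of these values should be read as the paper's configuration.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a 2-D convolution with dilation 1."""
    return (size + 2 * padding - kernel) // stride + 1


def conv_transpose_out(size, kernel, stride=1, padding=0, output_padding=0):
    """Spatial output size of a 2-D transposed convolution with dilation 1."""
    return (size - 1) * stride - 2 * padding + kernel + output_padding


# Illustration with an assumed kernel size of 3 and an assumed 7 x 7 input.
KERNEL = 3
after_conv1 = conv_out(7, KERNEL, padding=1)      # padding 1 keeps the size
after_conv2 = conv_out(after_conv1, KERNEL)       # no padding shrinks it by 2

size = 1                                          # assumed 1 x 1 map after reshaping
for _ in range(3):                                # three stride-1 transposed convolutions
    size = conv_transpose_out(size, KERNEL)
print(after_conv1, after_conv2, size)             # 7 5 7
```

With stride 1 and no padding, each transposed convolution grows the map by `kernel - 1` pixels, which is why a stack of them can expand a reshaped vector back to a spatial tensor.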

We also provide an improved solution, called TFH-ft, where our tensor feature hallucinator is fine-tuned on novel-class support examples at inference.

Baselines

To validate the benefit of generating tensor features, we also implement a vector feature hallucinator (VFH), where the embedding model is $f'_{\theta'}$ of (2), which includes GAP. In this case, the conditioner consists of two fully-connected layers with a ReLU activation in between. The generator also consists of two fully-connected layers with a ReLU activation in between and a sigmoid function at the end.
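A forward-pass sketch of VFH in NumPy. Some dimensions come from the paper: the 512-dimensional input matches the pooled ResNet-18 features, and the 512-dimensional class-conditional vector and hidden layers follow subsection A.2.

The rest are assumptions, not the authors' code:

- the noise dimension of 512;
- feeding the generator the concatenation of the conditional vector and the noise, which is described above only for TFH;
- the random initialization.

No training is shown.

```python
import numpy as np

rng = np.random.default_rng(0)
D = 512        # pooled ResNet-18 feature dimension
K_DIM = 512    # class-conditional vector and hidden size; noise size assumed equal


def linear(n_in, n_out):
    # Hypothetical random initialization, only so that the forward pass runs.
    return rng.normal(scale=1.0 / np.sqrt(n_in), size=(n_in, n_out)), np.zeros(n_out)


def relu(x):
    return np.maximum(x, 0.0)


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


COND = [linear(D, 512), linear(512, K_DIM)]        # conditioner: FC, ReLU, FC
GEN = [linear(2 * K_DIM, 512), linear(512, D)]     # generator: FC, ReLU, FC, sigmoid


def vfh_conditioner(p):
    """Map a vector prototype p of shape (D,) to a class-conditional vector of shape (K_DIM,)."""
    (w1, b1), (w2, b2) = COND
    return relu(p @ w1 + b1) @ w2 + b2


def vfh_generator(s, z):
    """Generate vector features from s (..., K_DIM) and noise z (..., K_DIM)."""
    (w1, b1), (w2, b2) = GEN
    x = np.concatenate([s, z], axis=-1)            # assumed generator input: [s; z]
    return sigmoid(relu(x @ w1 + b1) @ w2 + b2)


prototype = rng.random(D)                          # stand-in for a pooled class prototype
s = vfh_conditioner(prototype)
fake = vfh_generator(np.tile(s, (10, 1)), rng.standard_normal((10, K_DIM)))
print(s.shape, fake.shape)                         # (512,) (10, 512)
```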

Finally, we experiment with baselines consisting of the embedding network $f'_{\theta^*}$ of (1) or $f'_{\theta'}$ of (2) at representation learning and a prototypical classifier at inference, without feature hallucination. We refer to them as Baseline (1) and Baseline-KD (2), respectively.

Method                           | Backbone   | miniImageNet 1-shot | miniImageNet 5-shot | CUB 1-shot | CUB 5-shot | CIFAR-FS 1-shot | CIFAR-FS 5-shot
MetaGAN Zhang et al. (2018)      | ConvNet-4  | 52.71±0.64 | 68.63±0.67 | -          | -          | -          | -
Δ-encoder Schwartz et al. (2018) | VGG-16†    | 59.90      | 69.70      | 69.80±0.46 | 82.60±0.35 | 66.70      | 79.80
SalNet Zhang et al. (2019)       | ResNet-101 | 62.22±0.87 | 77.95±0.65 | -          | -          | -          | -
DTN Chen et al. (2020)           | ResNet-12  | 63.45±0.86 | 77.91±0.62 | 72.00      | 85.10      | 71.50      | 82.80
Dual TriNet Chen et al. (2019)   | ResNet-18  | 58.80±1.37 | 76.71±0.69 | 69.61      | 84.10      | 63.41±0.64 | 78.43±0.64
IDeMe-Net Chen et al. (2019)     | ResNet-18  | 59.14±0.86 | 74.63±0.74 | -          | -          | -          | -
AFHN Li et al. (2020)            | ResNet-18  | 62.38±0.72 | 78.16±0.56 | 70.53±1.01 | 83.95±0.63 | 68.32±0.93 | 81.45±0.87
VI-Net Luo et al. (2021)         | ResNet-18  | 61.05      | 78.60      | 74.76      | 86.84      | -          | -
Baseline (1)                     | ResNet-18  | 56.81±0.81 | 78.31±0.59 | 67.14±0.89 | 86.22±0.50 | 65.71±0.95 | 84.68±0.61
Baseline-KD (2)                  | ResNet-18  | 59.62±0.85 | 79.64±0.62 | 70.85±0.90 | 87.64±0.48 | 69.15±0.94 | 85.89±0.59
VFH (ours)                       | ResNet-18  | 61.92±0.85 | 77.02±0.64 | 75.25±0.86 | 87.96±0.48 | 72.60±0.93 | 84.26±0.67
TFH (ours)                       | ResNet-18  | 64.25±0.85 | 80.10±0.61 | 75.83±0.91 | 88.17±0.48 | 73.88±0.87 | 85.92±0.61
TFH-ft (ours)                    | ResNet-18  | 63.92±0.86 | 80.41±0.60 | 75.39±0.86 | 88.72±0.47 | 73.89±0.88 | 87.15±0.58

Table 1: Comparison of our proposed method variants and baselines to the state of the art on few-shot classification datasets. †: Δ-encoder uses a VGG-16 backbone for miniImageNet and CIFAR-FS and ResNet-18 for CUB. Baseline (1), Baseline-KD (2): prototypical classifier at inference, no feature generation. VFH: our vector feature hallucinator; TFH: our tensor feature hallucinator; TFH-ft: our tensor feature hallucinator followed by fine-tuning at inference. -: not reported.

3.2 Results

Table 1 compares our method with the state of the art. Most important are the comparisons with Chen et al. (2019, 2019); Li et al. (2020); Luo et al. (2021), which use the same backbone, ResNet-18. Our tensor feature hallucinator sets a new state of the art on all datasets and in all settings, outperforming all competing few-shot data augmentation methods. Fine-tuning at inference (TFH-ft) improves accuracy in all 5-shot settings, while in 1-shot settings its effect is negligible or slightly negative. This is expected, as more data means less risk of overfitting. The tensor feature hallucinator is consistently superior to the vector feature hallucinator, while the latter is still very competitive. Self-distillation (Baseline-KD vs. Baseline) also provides a significant performance boost in all experiments.
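The per-setting effects described above can be checked directly from Table 1. This short script recomputes two differences from the reported mean accuracies, ignoring the ± terms:

- the gain of Baseline-KD over Baseline (self-distillation);
- the gain of TFH-ft over TFH (fine-tuning).

```python
# Mean accuracies (%) copied from Table 1; the ± terms are ignored.
settings = ["miniImageNet 1-shot", "miniImageNet 5-shot", "CUB 1-shot",
            "CUB 5-shot", "CIFAR-FS 1-shot", "CIFAR-FS 5-shot"]
baseline = [56.81, 78.31, 67.14, 86.22, 65.71, 84.68]
baseline_kd = [59.62, 79.64, 70.85, 87.64, 69.15, 85.89]
tfh = [64.25, 80.10, 75.83, 88.17, 73.88, 85.92]
tfh_ft = [63.92, 80.41, 75.39, 88.72, 73.89, 87.15]


def gains(after, before):
    """Per-setting accuracy difference in percentage points, rounded to 2 decimals."""
    return [round(a - b, 2) for a, b in zip(after, before)]


for name, kd, ft in zip(settings, gains(baseline_kd, baseline), gains(tfh_ft, tfh)):
    print(f"{name:20s} self-distillation {kd:+.2f}   fine-tuning {ft:+.2f}")
```

All self-distillation gains are positive. The fine-tuning gains are positive in every 5-shot setting and range from -0.44 to +0.01 points in the 1-shot settings.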

4 Conclusion

Our solution is conceptually simple and improves the state of the art among data augmentation methods in the few-shot learning setting. We provided experimental evidence showing that using a simple loss function and exploiting the structural properties of tensors can significantly improve performance. Notably, the importance of using tensor features is evident from the comparison with vector features, which do not achieve similar performance. Potential future directions include investigating the performance of our method with different backbone architectures and in other experimental settings beyond few-shot learning.

References

Appendix A Appendix

A.1 Dataset details

miniImageNet

This is a widely used few-shot image classification dataset Vinyals et al. (2016); Ravi & Larochelle (2016). It contains 100 randomly sampled classes from ImageNet Krizhevsky et al. (2012). These 100 classes are split into 64 training (base) classes, 16 validation (novel) classes and 20 test (novel) classes. Each class contains 600 examples (images). We follow the commonly used split provided by Ravi & Larochelle (2016).

CUB

This is a fine-grained classification dataset consisting of 200 classes, each corresponding to a bird species. We follow the split defined by Chen et al. (2019); Hilliard et al. (2018), with 100 training, 50 validation and 50 test classes.

CIFAR-FS

This dataset is derived from CIFAR-100 Krizhevsky et al. (2009), consisting of 100 classes with 600 examples per class. We follow the split provided by Chen et al. (2019), with 64 training, 16 validation and 20 test classes.

All images from all datasets are resized to in a similar way to other data augmentation methods Li et al. (2020); Chen et al. (2019, 2019); Luo et al. (2021).

A.2 Implementation details

Our implementation is based on PyTorch Paszke et al. (2017).

Networks

In our tensor feature hallucinator (TFH), the embedding dimension is and the resolution is .

The convolutional layers of the conditioner use kernels of size and stride 1, and the input layer also uses padding 1. The channel dimensions are 512 and 256 for the first and second convolutional layers, respectively. The dimension of the class-conditional vector is set to . The tensor dimensions of all conditioner layers are , , (flattening) and .

All three transpose-convolutional layers of the generator use kernels of size , stride 1 and 512 channels. The dimension of is . The tensor dimensions of all generator layers are , , , and .

In our vector feature hallucinator (VFH), the dimensions of the class-conditional vector as well as the hidden layers of both the conditioner and the generator are all set to 512.

Training

For the embedding model, similarly to Tian et al. (2020), we use the SGD optimizer with learning rate 0.05, momentum 0.9 and weight decay 0.0005. For data augmentation, as in Lee et al. (2019), we adopt random cropping, color jittering and horizontal flipping.

The tensor feature hallucinator is trained in a meta-training regime with classes, examples per class and generation of class-conditioned examples in every task. We use the Adam optimizer with initial learning rate , halved every 10 epochs. We train for 50 epochs, where each epoch consists of 600 randomly sampled few-shot learning tasks. At test time, we find that generating more class-conditioned examples improves the accuracy; therefore, we generate class-conditioned examples.
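The hallucinator's learning-rate schedule is a step function of the epoch: the rate is halved every 10 epochs. The initial learning rate `LR0` below is a hypothetical placeholder.

In PyTorch, the same schedule would typically be obtained with `torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)`, stepped once per epoch.

```python
def hallucinator_lr(epoch, initial_lr):
    """Step schedule: the learning rate is halved after every 10 epochs (epochs counted from 0)."""
    return initial_lr * 0.5 ** (epoch // 10)


EPOCHS, TASKS_PER_EPOCH = 50, 600
print(EPOCHS * TASKS_PER_EPOCH)        # 30000 tasks sampled over the whole run

LR0 = 1e-3                             # hypothetical initial learning rate
for epoch in (0, 9, 10, 25, 49):
    print(epoch, hallucinator_lr(epoch, LR0))
```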

Our TFH-ft version uses the novel-class support examples to fine-tune all of its parameters. In the fine-tuning stage, we use exactly the same loss function as in the hallucinator training phase (4), and we fine-tune for 10 steps using the Adam optimizer with a learning rate of .