Cross-Domain Few-Shot Learning with Meta Fine-Tuning

05/21/2020 · John Cai, et al. · Princeton University

In this paper, we tackle the new Cross-Domain Few-Shot Learning benchmark proposed by the CVPR 2020 Challenge. To this end, we build upon state-of-the-art methods in domain adaptation and few-shot learning to create a system that can be trained to perform both tasks end-to-end. Inspired by the need to create models designed to be fine-tuned, we explore the integration of transfer learning (fine-tuning) with meta-learning algorithms, to train a network that has specific layers designed to be adapted at a later fine-tuning stage. To do so, we modify the episodic training process to include a first-order MAML-based meta-learning algorithm, and use a Graph Neural Network model as the subsequent meta-learning module. We find that our proposed method boosts accuracy significantly, especially when coupled with data augmentation. In our final results, we combine the novel method with the baseline method in a simple ensemble, and achieve an average accuracy of 73.78% on the benchmark. This is a 6.51% improvement over benchmark models that were trained solely on miniImagenet.


1 Introduction

In the past decade, vast improvements in visual recognition systems have been achieved by training deep neural networks on ever-expanding training data-sets. The ability of these models to generalize often depends directly on the size and variance of the training data-set. Unfortunately, acquiring a large training data-set is costly due to the need for human annotation. Furthermore, when dealing with rare examples in medical images (e.g. rare diseases) or satellite images (e.g. oil spills), the ability to obtain labelled samples is limited. Moreover, there is vast scope for improvement, as the human visual system is far less data-hungry than current deep learning methods.

To address the limitations of traditional deep learning, few-shot learning methods have emerged to train models that predict new classes by seeing only a few labelled images per class. These methods have shown promising improvements over the last 5 years.

However, existing few-shot learning methods have been developed under the assumption that the training and test data-sets arise from the same distribution. Domain shift thus poses an additional problem, as it may prevent the robust transfer of features.

Hence, the CVPR 2020 challenge has introduced a new benchmark that tests generalization ability across a range of vastly different domains, including natural and medical images, domains without perspective, and domains without color [6]. This robust evaluation framework allows one to truly test how few-shot learning models perform under sharp domain shifts. This is in contrast to previous work on domain adaptation [2], where the training and test domains were still fairly close to each other.

The main contribution of this paper is the integration of fine-tuning into the episodic training process by exploiting a first-order MAML-based meta-learning algorithm (henceforth “Meta Fine-Tuning”). This is done so that the network learns a set of initial weights that can be easily fine-tuned on the support set of the test domain.

Second, this paper integrates the above Meta Fine-Tuning algorithm into a Graph Neural Network that exploits the non-Euclidean structure of the relation between the support set and the query samples.

Third, as the baseline code-base only implements data augmentation for the training process, we implement data augmentation on the support set during fine-tuning, and achieve a further improvement in accuracy.

Finally, we combine the above method with a modified fine-tuning baseline method, and combine them into an ensemble to jointly make predictions.

2 Relevant Work

The key method that this paper builds on is Graph Neural Networks. Graph-based convolutions can create more flexible representations of data [1]. In the context of few-shot learning, after obtaining image features through deep networks, the problem can be recast as one of belief propagation: labels from labelled support examples are propagated to unlabelled query examples. Hence, by representing the set of support examples as a fully-connected undirected graph, we can add each query example to the graph and learn the edge weights [5].

Model-Agnostic Meta-Learning (MAML) was developed to train a network to learn a set of internal representations that can easily be adapted [4]. It has been shown that first-order approximations of MAML, such as Reptile [7], which ignore second-order derivatives, perform just as well on established benchmarks.

A key method for domain adaptation is to fix earlier feature layers, and fine-tune later feature layers on the support examples [6]. This could help transfer high-level features, while retraining domain-specific features.
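As a hedged illustration of this recipe, the following PyTorch sketch fixes all but the last residual block of a backbone and fine-tunes the remainder; the resnet18 backbone is a stand-in for illustration, not the model used in this paper:

```python
import torch
from torchvision.models import resnet18

# Stand-in backbone for illustration; in practice, load pretrained weights.
model = resnet18()
for p in model.parameters():
    p.requires_grad = False          # fix earlier feature layers
for p in model.layer4.parameters():
    p.requires_grad = True           # fine-tune the last residual block
for p in model.fc.parameters():
    p.requires_grad = True           # and the classifier head
optimizer = torch.optim.Adam(
    (p for p in model.parameters() if p.requires_grad), lr=1e-3)
```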

Averaging the prediction scores of different models to reduce variance is well-documented to achieve higher accuracy. In a few-shot learning context, variance is especially high because the model must generalize from only a few examples each time [3].

3 Methodology

3.1 Graph Neural Networks

Figure 1: GNN architecture from [5]

The meta-learning module used in this paper is the Graph Neural Network, as applied to few-shot learning in [5]. Labels are propagated from the support set, and graph convolutions are performed in each layer, so that each example in the query set is connected to all support nodes in the graph.
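The following is a minimal PyTorch sketch of one such layer in the spirit of [5], with edge weights learned from pairwise feature differences; the module names and dimensions are illustrative assumptions, not the exact architecture used here:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class EdgeMLP(nn.Module):
    """Learns a similarity score for each pair of node embeddings."""
    def __init__(self, dim, hidden=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden), nn.ReLU(),
            nn.Linear(hidden, 1))

    def forward(self, x):                                    # x: (N, dim)
        diff = (x.unsqueeze(1) - x.unsqueeze(0)).abs()       # (N, N, dim) pairwise |xi - xj|
        return F.softmax(self.net(diff).squeeze(-1), dim=-1) # row-normalized adjacency

class GraphConvLayer(nn.Module):
    """One graph convolution: aggregate neighbor features with learned edges."""
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.edge = EdgeMLP(in_dim)
        self.lin = nn.Linear(2 * in_dim, out_dim)

    def forward(self, x):
        adj = self.edge(x)          # (N, N) learned edge weights
        agg = adj @ x               # propagate labels/features across the graph
        return F.leaky_relu(self.lin(torch.cat([x, agg], dim=-1)))
```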

3.2 Meta Fine-Tuning

The core idea of meta fine-tuning is that instead of fine-tuning a pre-trained model that was not trained explicitly for fine-tuning, we can use meta-learning to find a set of weight initializations that are intended to be fine-tuned. To this end, we apply and adapt the first-order MAML algorithm [7] and simulate the episodic training process. A first-order MAML algorithm can achieve results comparable to the second-order algorithm at a lower computational cost.
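In first-order form, writing θ for the initial weights and θ′ for the weights after the inner fine-tuning loop, the outer update ignores second-order terms and applies the query-set gradient evaluated at θ′ directly to the initial weights (our notation, following [7]):

$$\theta \leftarrow \theta - \epsilon \, \nabla_{\theta'} \mathcal{L}_{\mathrm{query}}(\theta')$$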

Initialize weights θ = (θ_f, θ_l) for the feature extractor and φ for the metric-learning module;
for each episode do
       Sample support samples and query samples
       Freeze the first layers θ_f of the feature extractor
       for step = 1, 2, …, n do
             Sample a batch from the support samples
             Compute the loss on these support samples using a linear classifier
             Update the last layers of the feature extractor using SGD or Adam
       end for
       Obtain θ_l′, the updated weights for the last layers
       Combine θ_f and θ_l′ to obtain the new feature extractor
       Feed the query images through the feature extractor and then through the metric-learning module
       Compute the loss on the query samples and compute updates Δθ_f, Δθ_l, Δφ for all model parameters using Adam
       Update the initial parameters using learning rate ε:
             θ_f ← θ_f + ε Δθ_f
             θ_l ← θ_l + ε Δθ_l
             φ ← φ + ε Δφ
end for
Algorithm 1 Meta Fine-Tuning Algorithm
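Below is a compact PyTorch-style sketch of one episode of Algorithm 1, hedged as follows: it assumes a ResNet10-like encoder whose last block is exposed as .layer4 and a GNN module taking (support features, support labels, query features); these names and the hyperparameter values are illustrative, not the paper's exact implementation.

```python
import copy
import torch
import torch.nn.functional as F

def meta_fine_tune_episode(encoder, gnn, support_x, support_y, query_x, query_y,
                           inner_steps=5, inner_lr=1e-2, outer_lr=1e-3):
    """One episode of first-order meta fine-tuning (sketch of Algorithm 1)."""
    # Inner loop: fine-tune a copy of the last block on the support set,
    # with a throwaway linear classifier on the 512-d ResNet10 features.
    fast = copy.deepcopy(encoder)
    for p in fast.parameters():
        p.requires_grad = False
    for p in fast.layer4.parameters():        # assumed: last block is .layer4
        p.requires_grad = True
    n_way = int(support_y.max().item()) + 1
    clf = torch.nn.Linear(512, n_way).to(support_x.device)
    inner_opt = torch.optim.Adam(
        list(fast.layer4.parameters()) + list(clf.parameters()), lr=inner_lr)
    for _ in range(inner_steps):
        inner_opt.zero_grad()
        loss = F.cross_entropy(clf(fast(support_x)), support_y)
        loss.backward()
        inner_opt.step()

    # Outer loop: episodic loss through the fine-tuned encoder and the GNN.
    for p in fast.parameters():
        p.requires_grad = True
    logits = gnn(fast(support_x), support_y, fast(query_x))  # assumed interface
    outer_loss = F.cross_entropy(logits, query_y)
    grads = torch.autograd.grad(
        outer_loss, list(fast.parameters()) + list(gnn.parameters()))

    # First-order update: apply the query-set gradients, evaluated at the
    # fine-tuned weights, directly to the initial encoder and GNN weights.
    with torch.no_grad():
        for p, g in zip(list(encoder.parameters()) + list(gnn.parameters()), grads):
            p -= outer_lr * g
```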

The algorithm is model-agnostic and can be used with any existing metric-learning module. However, a parametric metric-learning module (one with learnable weights φ) should typically be used, as a completely non-parametric module such as Prototypical Networks, which uses the nearest centroid, may not be able to compare the subsequently fine-tuned image features in a robust way. In this paper, we use a Graph Neural Network (GNN) as the metric-learning module, as the GNN is flexible and not limited to making comparisons in Euclidean space.

The method can also be applied to a backbone of any depth, and any number of layers can be fine-tuned. For this paper, we fine-tune the last ResNet block of a ResNet10, freezing the earlier layers.

For further visualization, we include the figure below.

Figure 2: Meta-learning with fine-tuning. Green layers are trainable; brown layers are frozen. At test time, all layers are frozen in step 2.
No. of Shots   CropDisease       EuroSAT           ISIC              ChestX
5              96.27% ± 0.40%    89.83% ± 0.46%    61.71% ± 0.44%    28.44% ± 0.43%
20             98.91% ± 0.19%    93.90% ± 0.34%    65.29% ± 0.56%    35.62% ± 0.47%
50             99.48% ± 0.13%    96.08% ± 0.25%    75.13% ± 0.56%    44.70% ± 0.64%
Table 1: Final Proposed Model: Meta Fine-Tuning GNN + Modified Baseline Fine-Tuning + Data Augmentation
No. of Shots   CropDisease       EuroSAT           ISIC              ChestX
5              88.72% ± 0.53%    80.45% ± 0.54%    47.20% ± 0.45%    25.96% ± 0.46%
20             95.76% ± 0.65%    87.67% ± 0.44%    59.95% ± 0.45%    31.63% ± 0.49%
50             97.87% ± 0.48%    90.93% ± 0.45%    65.04% ± 0.47%    37.03% ± 0.50%
Table 2: Previous Benchmark’s Best Model “Ft-Last1” trained on miniImagenet, from [6]

During step 1 (meta fine-tuning), only support images are used, and the first 8 layers are frozen. A linear classifier on the ResNet10 features is used to predict the support labels, and the last 2 layers are updated accordingly. In step 2, all layers are updated using the episodic training loss. At the prediction stage on the test domain, all layers of the ResNet10 are frozen in step 2.

3.3 Data Augmentation

For data augmentation during training, we keep the default parameters of the code-base. For data augmentation during testing, we sample 17 additional images from the support images (whose labels we know) and apply jitter, random crops, and horizontal flips (where applicable) on a randomized basis. In the fine-tuning process, we weight the original images more by exposing the model to them more frequently. At the final prediction stage, only base images (which are center-crops) are used for both support and query images.
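A hedged sketch of what this support-set augmentation could look like with torchvision follows; the 17-image count is from the paper, but the specific jitter and crop magnitudes are assumptions, not the code-base's exact values:

```python
from torchvision import transforms
import torchvision.transforms.functional as TF

# Illustrative support-set augmentation; magnitudes are assumptions.
support_aug = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
    transforms.RandomHorizontalFlip(),  # omit for domains where flips are invalid
    transforms.ToTensor(),
])

def augment_support(images, n_aug=17):
    """Return center-cropped base tensors plus n_aug random variants per image.
    Weighting the originals more is done by sampling them more often elsewhere."""
    base = [TF.to_tensor(TF.center_crop(TF.resize(img, 256), 224)) for img in images]
    extra = [support_aug(img) for img in images for _ in range(n_aug)]
    return base, extra
```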

3.4 Combining Scores in the Ensemble

The baseline fine-tuning model was modified so that we fine-tune only the last ResNet block, using an Adam optimizer with weight decay for 20 epochs. For our final submission results, we combine the predictions of the modified baseline fine-tuning model and the meta fine-tuning GNN model: the scores of each model are normalized with a softmax function so that they sum to 1 and lie between 0 and 1, then added together, and the argmax is taken. We also implement transduction (as it is allowed in Track 1).
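The score combination reduces to a few lines; a minimal sketch (function name ours):

```python
import torch.nn.functional as F

def ensemble_predict(scores_baseline, scores_gnn):
    """Softmax-normalize each model's class scores so they sum to 1,
    add them, and take the argmax (Sec. 3.4)."""
    probs = F.softmax(scores_baseline, dim=-1) + F.softmax(scores_gnn, dim=-1)
    return probs.argmax(dim=-1)
```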

3.5 Memory Requirements of GNN on 50-shot

The GNN builds a fully-connected graph between all the support samples and each new query sample. Hence, the space requirement for the 50-shot setting is substantial, as the memory requirement scales quadratically in the number of nodes. Thus, in order to fit the model onto a 16GB Tesla V100, we average every 2 support samples’ feature vectors into 1, so that we obtain 25 support nodes for the GNN.
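The pairwise pooling is straightforward; a minimal sketch (function name ours, feature dimension assumed 512):

```python
import torch

def pool_support_pairs(features):
    """Average every 2 consecutive support feature vectors into 1,
    halving the number of GNN nodes (50 -> 25 in the 50-shot setting)."""
    n, d = features.shape            # e.g. (50, 512) per class
    assert n % 2 == 0
    return features.view(n // 2, 2, d).mean(dim=1)
```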

4 Submission Results (Table 1)

4.1 Experimental Setup

The experimental setup involves training on miniImagenet and testing on CropDisease, EuroSAT, ISIC and ChestX. Models are trained for 400 epochs and then meta fine-tuned for 200 epochs. During training, we augment the training dataset using image transformations, following the protocol in [2]. For the ensemble models and the data augmentation, we modify the evaluation code in [6] to ensure that the same test images are used and add checks to ensure that the same base images are used per-episode despite the randomization in testing.

4.2 Proposed Model

As shown in Table 1 and Table 2, the final proposed model vastly outperforms the previous benchmark introduced by [6]. The average accuracy across all 12 tasks is 73.78% for the proposed model (Table 1), compared to 67.27% for the previous benchmark (Table 2). This is a 6.51% improvement over the benchmark model that was trained solely on the miniImagenet dataset.

Even compared with the benchmark model that was trained on multiple datasets (while the proposed model is not, due to the requirements of the challenge), the proposed model still shows a large improvement, as the previous IMS-f model achieved an average accuracy of 68.69% [6].

We further observe that the improvement in accuracy is most pronounced at the 5-shot level, with an 8.48% gain over the baseline Ft-Last1 model, followed by a 6.38% gain at the 50-shot level and a 4.68% gain at the 20-shot level. This non-linear pattern may be attributed to the relative effects of data augmentation and meta fine-tuning: data augmentation likely has the greatest effect when the number of support examples is very low, while fine-tuning has the greatest effect when the number of support examples is very high. This is supported by the subsequent analysis.

No. of Shots   CropDisease       EuroSAT           ISIC              ChestX
5              96.14% ± 0.43%    87.13% ± 0.58%    53.00% ± 0.45%    26.76% ± 0.45%
20             98.66% ± 0.43%    95.01% ± 0.33%    62.72% ± 0.73%    32.83% ± 0.45%
Table 3: Single Model Study: Meta Fine-Tuning GNN + Data Augmentation
No. of Shots   CropDisease       EuroSAT           ISIC              ChestX
5              92.23% ± 0.46%    82.67% ± 0.50%    61.76% ± 0.50%    31.60% ± 0.41%
20             95.95% ± 0.30%    87.84% ± 0.46%    60.32% ± 0.59%    35.91% ± 0.42%
Table 4: Single Model Study: Modified Baseline Fine-Tuning + Data Augmentation

5 Further Analysis of Results

5.1 Single Model Study

For brevity, we only show results for 5 and 20 shots for the individual models. We see a clear pattern: meta fine-tuning contributes more to accuracy on domains that are close to the training domain, while baseline fine-tuning + DA is most effective on domains more distant from the training domain, such as ChestX. We also find that the improvement from 5 to 20 shots is less pronounced for baseline fine-tuning + DA.

5.2 Ablation Study: GNN and Simple Fine-Tuning

To investigate the effects of data augmentation and meta fine-tuning, we perform the ablation study below. Simple fine-tuning refers to taking an existing GNN model and simply fine-tuning the last ResNet block. For brevity, we present 20-shot results for EuroSAT and ISIC.

GNN Method      EuroSAT           ISIC
No FT           86.57% ± 0.63%    52.32% ± 0.64%
Simp FT         90.30% ± 0.47%    61.04% ± 0.64%
Simp FT + DA    94.60% ± 0.37%    63.34% ± 0.72%
Meta FT + DA    95.01% ± 0.33%    62.72% ± 0.73%
Table 5: 20-Shot: Further experiments with fine-tuning of the GNN. DA refers to data augmentation.

From the above, we see that meta fine-tuning and simple fine-tuning achieve comparable results when paired with data augmentation. On domains more similar to miniImagenet, meta fine-tuning performs slightly better, while on domains further from the training set, simple fine-tuning performs better. Still, the study validates the use of fine-tuning and meta-learning together for cross-domain few-shot learning.

6 Conclusion

In this paper, we have developed a model that outperforms the benchmark by 6.51%, using meta fine-tuning with a GNN, data augmentation, and ensemble methods.

The individual model performance suggests that meta fine-tuning does especially well on domains close to the source domain, such as CropDisease and EuroSAT, while traditional simple fine-tuning with data augmentation works better on domains that are further away. One possible reason is that learning how to fine-tune on the training dataset may restrict the effectiveness of fine-tuning under domain shift, as the fine-tuning process becomes domain-specific.

The ablation study suggests that most of the improvement in accuracy is driven by fine-tuning, followed by data augmentation, with weaker evidence for the additional benefit of the MAML-based meta fine-tuning itself. Further research could train a meta fine-tuning process that is more domain-agnostic, or combine meta fine-tuning with simple fine-tuning to create a model that can handle any domain shift.

References

  • [1] M. M. Bronstein, J. Bruna, Y. LeCun, A. Szlam, and P. Vandergheynst (2017) Geometric deep learning: going beyond Euclidean data. IEEE Signal Processing Magazine 34 (4), pp. 18–42.
  • [2] W. Chen, Y. Liu, Z. Kira, Y. F. Wang, and J. Huang (2019) A closer look at few-shot classification. arXiv preprint arXiv:1904.04232.
  • [3] N. Dvornik, C. Schmid, and J. Mairal (2019) Diversity with cooperation: ensemble methods for few-shot classification. In Proceedings of the IEEE International Conference on Computer Vision, pp. 3723–3731.
  • [4] C. Finn, P. Abbeel, and S. Levine (2017) Model-agnostic meta-learning for fast adaptation of deep networks. In Proceedings of the 34th International Conference on Machine Learning, Volume 70, pp. 1126–1135.
  • [5] V. Garcia and J. Bruna (2017) Few-shot learning with graph neural networks. arXiv preprint arXiv:1711.04043.
  • [6] Y. Guo, N. C. Codella, L. Karlinsky, J. R. Smith, T. Rosing, and R. Feris (2019) A new benchmark for evaluation of cross-domain few-shot learning. arXiv preprint arXiv:1912.07200.
  • [7] A. Nichol, J. Achiam, and J. Schulman (2018) On first-order meta-learning algorithms. arXiv preprint arXiv:1803.02999.