DPGN: Distribution Propagation Graph Network for Few-shot Learning

03/31/2020 ∙ by Ling Yang, et al. ∙ Megvii Technology Limited ∙ NetEase, Inc

We extend this idea further to explicitly model the distribution-level relation of one example to all other examples in a 1-vs-N manner. We propose a novel approach named distribution propagation graph network (DPGN) for few-shot learning. It conveys both the distribution-level relations and instance-level relations in each few-shot learning task. To combine the distribution-level relations and instance-level relations for all examples, we construct a dual complete graph network which consists of a point graph and a distribution graph, with each node standing for an example. Equipped with this dual-graph architecture, DPGN propagates label information from labeled examples to unlabeled examples within several update generations. In extensive experiments on few-shot learning benchmarks, DPGN outperforms state-of-the-art results by a large margin: 5% to 12% in the supervised setting and 7% to 13% in the semi-supervised setting.


1 Introduction

The success of deep learning is rooted in a large amount of labeled data [19, 38], while humans generalize well after having seen few examples. The contradiction between these two facts has drawn great attention to research on few-shot learning [7, 20]. A few-shot learning task aims at predicting labels for unlabeled data (the query set) given a few labeled data (the support set).

Figure 1: Our proposed DPGN adopts contrastive comparisons between each sample and the support samples to produce a distribution representation. It then incorporates distribution-level comparisons with instance-level comparisons when classifying the query sample.

Fine-tuning [4] is the de facto method for obtaining a predictive model from a small training dataset in practice. However, it suffers from overfitting [11]. Meta-learning [8] methods introduce the concept of an episode to address the few-shot problem explicitly. An episode is one round of model training in which only a few examples (e.g., 1 or 5) are randomly sampled from each class in the training data. Meta-learning methods adopt a trainer (also called a meta-learner) which takes the few-shot training data and outputs a classifier. This process is called episodic training [41]. Under the meta-learning framework, diverse hypotheses have been made to build an efficient meta-learner.

A rising trend in recent research is to process the training data with Graph Networks [2], a powerful model that generalizes many data structures (lists, trees) while introducing a combinatorial prior over data. Few-Shot GNN [10] builds a complete graph network where each node feature is concatenated with the corresponding class label; node features are then updated via the attention mechanism of the graph network to propagate the label information. To further exploit intra-cluster similarity and inter-cluster dissimilarity in the graph-based network, EGNN [18] demonstrates an edge-labeling graph neural network under the episodic training framework. Previous GNN studies in few-shot learning have mainly focused on pair-wise relations such as node labeling or edge labeling, and have ignored a large number of substantial distribution relations. Additionally, other meta-learning approaches claim to benefit from global relations through episodic training, but only implicitly.

As illustrated in Figure 1, we first extract the instance features of the support and query samples. Then, we obtain the distribution feature for each sample by calculating the instance-level similarity over all support samples. To leverage both the instance-level and distribution-level representations of each example and process the representations at different levels independently, we propose a dual-graph architecture: a point graph (PG) and a distribution graph (DG). Specifically, the PG generates the DG by gathering the 1-vs-N relation on every example, while the DG refines the PG by delivering distribution relations between each pair of examples. This cyclic transformation fuses instance-level and distribution-level relations, and multiple generations (rounds) of this Gather-Compare process constitute our approach. Furthermore, it is easy to extend DPGN to the semi-supervised few-shot learning task, where the support set contains both labeled and unlabeled samples for each class. DPGN builds a bridge between labeled and unlabeled samples in the form of similarity distributions, which leads to better propagation of label information in semi-supervised few-shot classification.

Our main contributions are summarized as follows:

  • To the best of our knowledge, DPGN is the first to explicitly incorporate distribution propagation in a graph network for few-shot learning. Further ablation studies demonstrate the effectiveness of distribution relations.

  • We devise the dual complete graph network that combines instance-level and distribution-level relations. The cyclic update policy in this framework contributes to enhancing instance features with distribution information.

  • Extensive experiments are conducted on four popular benchmark datasets for few-shot learning. Compared with state-of-the-art methods, DPGN achieves a significant improvement of 5% to 12% on average in few-shot classification accuracy. In semi-supervised tasks, our algorithm outperforms existing graph-based few-shot learning methods by 7% to 13%.

2 Related Work

2.1 Graph Neural Network

Graph neural networks were first designed for tasks on processing graph-structured data [34, 41]. They mainly refine the node representations by aggregating and transforming neighboring nodes recursively. Recent approaches [10, 25, 18] exploit GNNs for few-shot learning. TPN [25] brings the transductive setting into graph-based few-shot learning, using a Laplacian matrix to propagate labels from the support set to the query set in the graph. It also considers the similarity between support and query samples through pairwise node-feature affinities to propagate labels. EGNN [18] uses the similarity/dissimilarity between samples and dynamically updates both node and edge features for complicated interactions.

2.2 Metric Learning

Another category of few-shot learning approaches focuses on optimizing feature embeddings of input data using metric learning methods. Matching Networks [41] produce a weighted nearest-neighbor classifier by computing the embedding distance between the support and query sets. Prototypical Networks [36] first build a prototype representation of each class in the embedding space. As an extension of Prototypical Networks, IMF [1] constructs infinite mixture prototypes by self-adaptation. RelationNet [40] adopts a distance metric network to learn pointwise relations in support and query samples.

Figure 2: The overall framework of DPGN. In this illustration, we take a 2way-1shot task as an example. The support and query embeddings obtained from the feature extractor are delivered to the dual complete graph (a point graph and a distribution graph) for transductive propagation, generation after generation. The green arrow represents an edge-to-node transformation (P2D, described in Section 3.2.1) which aggregates instance similarities to construct distribution representations, and the blue arrow represents another edge-to-node transformation (D2P, described in Section 3.2.2) which aggregates distribution similarities with instance features. DPGN makes the prediction for the query sample at the end of the final generation.

2.3 Distribution Learning

Distribution learning theory was first introduced in [17] to find an efficient algorithm that determines the distribution from which the samples are drawn. Different methods [16, 5, 6] have been proposed to efficiently estimate the target distributions. DLDL [9] is one of the works that assign a discrete distribution instead of a one-hot label to each instance in classification and regression tasks. CPNN [44] takes both features and labels as inputs and produces the label distribution with only one hidden layer in its framework. LDLFs [35] devises a distribution learning method based on the decision tree algorithm.

2.4 Meta Learning

Some few-shot approaches adopt a meta-learning framework that learns meta-level knowledge across batches of tasks. MAML [8] is a gradient-based approach that designs the meta-learner as an optimizer that can learn to update the model parameters (e.g., all layers of a deep network) within a few optimization steps given novel examples. Reptile [28] simplifies the computation of the meta-loss by incorporating an L2 loss which updates the meta-model parameters towards the instance-specific adapted models. SNAIL [27] learns a parameterized predictor to estimate the parameters in models. MetaOptNet [21] advocates the use of linear classifiers instead of nearest-neighbor methods, which can be optimized as convex learning problems. LEO [33] utilizes an encoder-decoder architecture to mine the latent generative representations and predicts high-dimensional parameters in extreme low-data regimes.

3 Method

In this section, we first provide the background of the few-shot learning task, then introduce the proposed algorithm in detail.

3.1 Problem Definition

The goal of few-shot learning tasks is to train a model that can perform well in the case where only few samples are given.

Each few-shot task has a support set and a query set . Given training data , the support set contains N classes with K samples for each class (i.e., the N-way K-shot setting), it can be denoted as . The query set has samples and can be denoted as . Specifically, in the training stage, data labels are provided for both support set and query set . Given testing data , our goal is to train a classifier that can map the query sample from to the corresponding label accurately with few support samples from . Labels of support sets and query sets are mutually exclusive.
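As a rough illustration of the setting described above, the following sketch samples one N-way K-shot episode from class-indexed data. The function name `sample_episode`, the dict-of-lists data layout, and the per-class query count `n_query` are assumptions made for this example and are not taken from the paper.

```python
import random


def sample_episode(data_by_class, n_way, k_shot, n_query, rng):
    """Sample one N-way K-shot episode (illustrative sketch).

    data_by_class: dict mapping a class id to a list of samples
                   (hypothetical layout).
    n_query:       query samples drawn per class (an assumption).
    Returns (support, query), each a list of (sample, episode_label) pairs.
    """
    classes = rng.sample(sorted(data_by_class), n_way)
    support, query = [], []
    for episode_label, cls in enumerate(classes):
        # Draw disjoint support and query samples for this class.
        picked = rng.sample(data_by_class[cls], k_shot + n_query)
        support += [(x, episode_label) for x in picked[:k_shot]]
        query += [(x, episode_label) for x in picked[k_shot:]]
    return support, query


# Toy data: 10 classes with 20 dummy samples each.
toy = {c: [f"img_{c}_{i}" for i in range(20)] for c in range(10)}
support, query = sample_episode(
    toy, n_way=5, k_shot=1, n_query=3, rng=random.Random(0)
)
```

In a 5way-1shot episode sampled this way, the support set holds five labeled samples and every query sample belongs to one of the five sampled classes.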

3.2 Distribution Propagation Graph Networks

In this section, we explain the proposed DPGN for few-shot learning in detail. As shown in Figure 2, the DPGN consists of generations and each generation consists of a point graph and a distribution graph . Firstly, the feature embeddings of all samples are extracted by a convolutional backbone; these embeddings are used to compute the instance similarities . Secondly, the instance relations are delivered to construct the distribution graph . The node features are initialized by aggregating following the position order in and the edge features stand for the distribution similarities between the node features . Finally, the obtained is delivered to for constructing more discriminative representations of nodes, and we repeat the above procedure generation by generation. A brief introduction of generation update for the DPGN can be expressed as , where denotes the -th generation.

For further explanation, we formulate , , and as follows: , , , where . denotes the total number of examples in a training episode. is first initialized by the output of the feature extractor . For each sample :

(1)

where and denotes the dimension of the feature embedding.

Figure 3: Details about P2D aggregation and D2P aggregation in DPGN. A 2way-1shot task is presented as an example. MLP-1 is the FC-ReLU blocks mentioned in P2D Aggregation and MLP-2 is the Conv-BN-ReLU blocks mentioned in D2P Aggregation. The green arrow denotes the P2D aggregation while the blue arrow denotes the D2P aggregation. Both aggregation processes integrate the node or edge features of their previous generation.

3.2.1 Point-to-Distribution Aggregation

Point Similarity

Each edge in the point graph stands for the instance (point) similarity and the edge of the first generation is initialized as follows:

(2)

where . is the encoding network that transforms the instance similarity to a certain scale. contains two Conv-BN-ReLU [13, 15] blocks with the parameter set and a sigmoid layer.

For generation , given , and , can be updated as follows:

(3)

In order to use edge information with a holistic view of the graph , a normalization operation is conducted on the .
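The point-similarity step can be pictured with the sketch below. Because the equations are not reproduced in this text, several details are assumptions and not the paper's definition: the encoder input is taken to be the element-wise squared difference of two embeddings, the learned Conv-BN-ReLU encoder is replaced by a fixed stand-in (a scaled sigmoid of the negated mean), later generations multiply the new score by the previous edge, and normalization divides each row by its sum.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def point_similarity(nodes, prev_edges=None):
    """Pairwise instance similarity between node embeddings (sketch).

    nodes:      list of embedding vectors, one per sample.
    prev_edges: edge matrix of the previous generation, or None for
                the first generation.
    All modelling choices here are assumptions (see the lead-in).
    """
    t = len(nodes)
    edges = [[0.0] * t for _ in range(t)]
    for i in range(t):
        for j in range(t):
            sq_diff = [(a - b) ** 2 for a, b in zip(nodes[i], nodes[j])]
            # Stand-in for the learned encoder: 1.0 for identical
            # embeddings, approaching 0.0 as they move apart.
            score = 2.0 * sigmoid(-sum(sq_diff) / len(sq_diff))
            if prev_edges is not None:
                score *= prev_edges[i][j]
            edges[i][j] = score
    # Row normalization: each edge is expressed relative to all edges
    # leaving the same node.
    return [[e / sum(row) for e in row] for row in edges]


nodes = [[0.0, 0.0], [0.1, 0.0], [3.0, 3.0]]
gen0 = point_similarity(nodes)
gen1 = point_similarity(nodes, prev_edges=gen0)
```

With these toy embeddings, the first two nodes are close together, so the edge between them is stronger than the edge from either to the third node.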

P2D Aggregation

After edge features in point graph are produced or updated, the distribution graph is the next to be constructed. As shown in Figure 3, aims at integrating instance relations from the point graph and processing the distribution-level relations. Each distribution feature in is a dimension feature vector where the value in the -th entry represents the relation between sample and sample and stands for the total number of support samples in a task. For first initialization:

(4)

where and is the concatenation operator. is the Kronecker delta function which outputs one when and zero otherwise ( and are labels).
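A minimal sketch of this initialization follows. For a labeled sample, entry j is the Kronecker delta of its label and the label of support sample j, as described above. The uniform vector used for a sample without a label is an assumption of this sketch, since the text here does not state how such samples are initialized; the function name is also hypothetical.

```python
def init_distribution_feature(label, support_labels):
    """Initial distribution feature over all support samples (sketch).

    label:          class label of the sample, or None if unknown.
    support_labels: labels of the support samples, in position order.
    """
    n_support = len(support_labels)
    if label is None:
        # Assumption: no label information, so spread mass uniformly.
        return [1.0 / n_support] * n_support
    # Kronecker delta between this label and each support label.
    return [1.0 if label == y_j else 0.0 for y_j in support_labels]


support_labels = [0, 1]  # a 2way-1shot task
labeled_feature = init_distribution_feature(1, support_labels)
unlabeled_feature = init_distribution_feature(None, support_labels)
```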

For generations , the distribution node can be updated as follows:

(5)

where is the aggregation network for distribution graph. applies a concatenation operation between two features. Then, P2D performs a transformation on the concatenated features which is composed of a fully-connected layer and ReLU [13], with the parameter set .
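The concatenate-then-transform pattern of P2D can be sketched as below. It is assumed here that the two concatenated inputs are the sample's point-graph edges to the support samples and its distribution feature from the previous generation; the weights are fixed toy values standing in for the learned fully-connected layer.

```python
def p2d_aggregate(point_edges_row, prev_dist_feature, weight, bias):
    """P2D-style aggregation: concatenate, then FC + ReLU (sketch).

    point_edges_row:   edges from one sample to each support sample.
    prev_dist_feature: that sample's previous distribution feature.
    weight, bias:      parameters of a single fully-connected layer
                       (out_dim rows of in_dim weights); toy values here.
    """
    concat = list(point_edges_row) + list(prev_dist_feature)
    out = []
    for w_row, b in zip(weight, bias):
        z = sum(w * x for w, x in zip(w_row, concat)) + b
        out.append(max(0.0, z))  # ReLU
    return out


# Toy layer that averages each edge with the matching previous entry.
weight = [[0.5, 0.0, 0.5, 0.0],
          [0.0, 0.5, 0.0, 0.5]]
bias = [0.0, 0.0]
new_feature = p2d_aggregate([0.8, 0.2], [1.0, 0.0], weight, bias)
```

The output keeps one entry per support sample, so the updated distribution feature has the same dimension as the one it replaces.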

3.2.2 Distribution-to-Point Aggregation

Distribution Similarity

Each edge in distribution graph stands for the similarity between distribution features of different samples. For generation , the distribution similarity is initialized as follows:

(6)

where . The encoding network transforms the distribution similarity using two Conv-BN-ReLU blocks with the parameter set and a sigmoid layer in the end. For generation , the update rule for in is formulated as follows:

(7)

Also, we apply a normalization to .

D2P Aggregation

As illustrated in Figure 3, the encoded distribution information in flows back into the point graph at the end of each generation. Then node features in capture the distribution relations through aggregating all the node features in with edge features as follows:

(8)

where and is the aggregation network for point graph in with the parameter set . concatenates the feature which is computed by with the node features in previous generation and updates the concatenated feature with two Conv-BN-ReLU blocks. After this process, the node features can integrate the distribution-level information into the instance-level feature and prepare for computing instance similarities in the next generation.
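The D2P step described above can be sketched as an edge-weighted sum of node features followed by concatenation with the node's own previous feature. The `transform` callable is a hypothetical stand-in for the two learned Conv-BN-ReLU blocks; the toy transform used here simply averages the two halves and applies ReLU.

```python
def d2p_aggregate(dist_edges, prev_nodes, transform):
    """D2P-style aggregation (illustrative sketch).

    dist_edges: matrix of distribution-graph edge features.
    prev_nodes: point-graph node features of the previous generation.
    transform:  stand-in for the learned blocks; maps the concatenated
                vector back to a node feature (an assumption).
    """
    t = len(prev_nodes)
    dim = len(prev_nodes[0])
    new_nodes = []
    for i in range(t):
        # Weighted sum of all node features using distribution edges.
        aggregated = [
            sum(dist_edges[i][j] * prev_nodes[j][d] for j in range(t))
            for d in range(dim)
        ]
        new_nodes.append(transform(aggregated + list(prev_nodes[i])))
    return new_nodes


def halves_mean_relu(features):
    """Toy transform: average the two halves, then ReLU."""
    half = len(features) // 2
    return [max(0.0, (a + b) / 2.0)
            for a, b in zip(features[:half], features[half:])]


nodes = [[1.0, 2.0], [3.0, 4.0]]
same = d2p_aggregate([[1.0, 0.0], [0.0, 1.0]], nodes, halves_mean_relu)
mixed = d2p_aggregate([[0.5, 0.5], [0.5, 0.5]], nodes, halves_mean_relu)
```

With identity edges each node only sees itself and is left unchanged, while uniform edges pull every node towards the mean of all nodes.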

3.3 Objective

The class prediction of each node can be computed by feeding the corresponding edges in the final generation of DPGN into softmax function:

(9)

where is the probability distribution over classes given sample , and is the label of the -th sample in the support set. stands for the edge feature in the point graph at the final generation.
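One way to read this prediction rule is sketched below: the final-generation edges from a query node to each support node are accumulated per support class, and a softmax turns the per-class scores into probabilities. Accumulating the edges by support label is an assumption of this sketch, as the equation itself is not reproduced in this text.

```python
import math


def predict_class_probs(edges_to_support, support_labels, n_classes):
    """Class probabilities for one query sample (illustrative sketch).

    edges_to_support: final point-graph edges from the query node to
                      each support node.
    support_labels:   class index of each support node.
    """
    scores = [0.0] * n_classes
    for edge, label in zip(edges_to_support, support_labels):
        scores[label] += edge  # edge weight times one-hot label
    # Numerically stable softmax over the per-class scores.
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [x / total for x in exps]


probs = predict_class_probs([0.1, 0.8, 0.1], [0, 1, 2], n_classes=3)
predicted = probs.index(max(probs))
```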

Point Loss

It is noted that we make classification predictions in the point graph for each sample. Therefore, the point loss at generation is defined as follows:

(10)

where is the cross-entropy loss function, stands for the number of samples in each task . and are model probability predictions of sample and the ground-truth label respectively.

Distribution Loss

To facilitate the training process and learn discriminative distribution features , we incorporate the distribution loss which plays a significant role in contributing to faster and better convergence. We define the distribution loss for generation as follows:

(11)

where stands for the edge feature in the distribution graph at generation .

The total objective function is a weighted summation of all the losses mentioned above:

(12)

where denotes total generations of DPGN and the weights and of each loss are set to balance their importance. In most of our experiments, and are set to 1.0 and 0.1 respectively.
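The weighted summation can be sketched for a single sample as follows. It is assumed that the weight 1.0 applies to the point loss and 0.1 to the distribution loss (following the order in which the losses are introduced), and that the distribution loss is also a cross-entropy on a class-probability vector; the names `lambda_p` and `lambda_d` are illustrative.

```python
import math


def cross_entropy(probs, label):
    """Cross-entropy of one probability vector against a class index."""
    return -math.log(probs[label])


def total_loss(point_probs, dist_probs, label, lambda_p=1.0, lambda_d=0.1):
    """Weighted sum of point and distribution losses (sketch).

    point_probs / dist_probs: one probability vector per generation,
    predicted from the point graph and the distribution graph.
    """
    total = 0.0
    for p_probs, d_probs in zip(point_probs, dist_probs):
        total += lambda_p * cross_entropy(p_probs, label)
        total += lambda_d * cross_entropy(d_probs, label)
    return total


# Two generations, binary toy task, ground-truth class 0.
loss = total_loss(
    point_probs=[[0.5, 0.5], [0.5, 0.5]],
    dist_probs=[[0.5, 0.5], [0.5, 0.5]],
    label=0,
)
```

Summing over every generation, not only the last one, means that each intermediate graph receives a direct training signal.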

4 Experiments

4.1 Datasets and Setups

4.1.1 Datasets

We evaluate DPGN on four standard few-shot learning benchmarks: miniImageNet [41], tieredImageNet [31], CUB-200-2011 [42] and CIFAR-FS [3]. miniImageNet and tieredImageNet are subsets of ImageNet [32]. CUB-200-2011 was initially designed for fine-grained classification and CIFAR-FS is a subset of CIFAR-100 for few-shot classification. Table 1 lists the number of images, number of classes, image resolution and train/val/test splits, following the criteria of previous works [41, 31, 4, 3].

Dataset          Images   Classes   Train/val/test   Resolution
miniImageNet     60000    100       64/16/20         84x84
tieredImageNet   779165   608       351/97/160       84x84
CUB-200-2011     11788    200       100/50/50        84x84
CIFAR-FS         60000    100       64/16/20         32x32
Table 1: Details for few-shot learning benchmarks.

4.1.2 Experiment Setups

Network Architecture

We use four popular networks for fair comparison: ConvNet, ResNet12, ResNet18 and WRN, which are used in EGNN [18], MetaOptNet [21], CloserLook [4] and LEO [33] respectively. ConvNet mainly consists of four Conv-BN-ReLU blocks. The last two blocks also contain a dropout layer [37]. ResNet12 and ResNet18 are the same as the ones described in [14]. They mainly have four blocks, which include one residual block for ResNet12 and two residual blocks for ResNet18 respectively. WRN was first proposed in [46]. It mainly has three residual blocks and the depth of the network is set to 28 as in [33]. The last features of all backbone networks are processed by global average pooling, followed by a fully-connected layer with batch normalization [15] to obtain a 128-dimensional instance embedding.

Training Schema

We perform data augmentation before training, such as horizontal flip, random crop, and color jitter (brightness, contrast, and saturation), which are mentioned in [11, 43]. We randomly sample 28 meta-task episodes in each iteration for meta-training. The Adam optimizer is used in all experiments with the initial learning rate of . We decay the learning rate by 0.1 per 15000 iterations and set the weight decay to .
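The decay schedule above can be sketched as a step function. The initial learning rate is not given in this text, so `base_lr` is left as a parameter and the value used in the example is purely hypothetical; treating "decay by 0.1 per 15000 iterations" as a multiplicative step decay is also an assumption.

```python
def learning_rate(iteration, base_lr, decay=0.1, step=15000):
    """Step decay: multiply by `decay` once every `step` iterations."""
    return base_lr * decay ** (iteration // step)


# base_lr=1.0 is a placeholder chosen only to make the factors visible.
schedule = [learning_rate(it, base_lr=1.0) for it in (0, 14999, 15000, 30000)]
```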

Evaluation Protocols

We evaluate DPGN in 5way-1shot/5shot settings on standard few-shot learning datasets: miniImageNet, tieredImageNet, CUB-200-2011 and CIFAR-FS. We follow the evaluation process of previous approaches [18, 33, 43]. We randomly sample 10,000 tasks, then report the mean accuracy (in %) as well as the 95% confidence interval.
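A common way to obtain such a report is sketched below using a normal approximation, where the half-width of the interval is 1.96 standard errors. The text does not say how the interval is computed, so this formula is an assumption, and the accuracies in the example are made-up toy values.

```python
import math
import statistics


def mean_and_ci95(accuracies):
    """Mean and 95% confidence half-width (normal approximation)."""
    mean = statistics.fmean(accuracies)
    std_err = statistics.stdev(accuracies) / math.sqrt(len(accuracies))
    return mean, 1.96 * std_err


# Toy per-task accuracies in percent (not results from the paper).
mean_acc, half_width = mean_and_ci95([60.0, 70.0, 80.0, 70.0])
```

Because the half-width shrinks with the square root of the number of tasks, sampling 10,000 tasks gives intervals far tighter than the spread of individual task accuracies.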

4.2 Experiment Results

Main Results

We compare the performance of DPGN with several state-of-the-art models, including graph and non-graph methods. For fair comparison, we evaluate DPGN on the miniImageNet, tieredImageNet, CIFAR-FS and CUB-200-2011 datasets and compare it with other methods using the same backbones. As shown in Tables 2, 3 and 4, the proposed DPGN is superior to other existing methods and achieves state-of-the-art performance, especially compared with the graph-based methods.

Method               Backbone   5way-1shot     5way-5shot
MatchingNet [41]     ConvNet    43.56 ± 0.84   55.31 ± 0.73
ProtoNet [36]        ConvNet    49.42 ± 0.78   68.20 ± 0.66
RelationNet [40]     ConvNet    50.44 ± 0.82   65.32 ± 0.70
R2D2 [3]             ConvNet    51.20 ± 0.60   68.20 ± 0.60
MAML [8]             ConvNet    48.70 ± 1.84   55.31 ± 0.73
Dynamic [11]         ConvNet    56.20 ± 0.86   71.94 ± 0.57
GNN [10]             ConvNet    50.33 ± 0.36   66.41 ± 0.63
TPN [25]             ConvNet    55.51 ± 0.86   69.86 ± 0.65
Global [26]          ConvNet    53.21 ± 0.40   72.34 ± 0.32
Edge-label [18]      ConvNet    59.63 ± 0.52   76.34 ± 0.48
DPGN                 ConvNet    66.01 ± 0.36   82.83 ± 0.41
LEO [33]             WRN        61.76 ± 0.08   77.59 ± 0.12
wDAE [12]            WRN        61.07 ± 0.15   76.75 ± 0.11
DPGN                 WRN        67.24 ± 0.51   83.72 ± 0.44
CloserLook [4]       ResNet18   51.75 ± 0.80   74.27 ± 0.63
CTM [22]             ResNet18   62.05 ± 0.55   78.63 ± 0.06
DPGN                 ResNet18   66.63 ± 0.51   84.07 ± 0.42
MetaGAN [47]         ResNet12   52.71 ± 0.64   68.63 ± 0.67
SNAIL [27]           ResNet12   55.71 ± 0.99   68.88 ± 0.92
TADAM [29]           ResNet12   58.50 ± 0.30   76.70 ± 0.30
Shot-Free [30]       ResNet12   59.04 ± 0.43   77.64
Meta-Transfer [39]   ResNet12   61.20 ± 1.80   75.53 ± 0.80
FEAT [43]            ResNet12   62.96 ± 0.02   78.49 ± 0.02
TapNet [45]          ResNet12   61.65 ± 0.15   76.36 ± 0.10
Dense [24]           ResNet12   62.53 ± 0.19   78.95 ± 0.13
MetaOptNet [21]      ResNet12   62.64 ± 0.61   78.63 ± 0.46
DPGN                 ResNet12   67.77 ± 0.32   84.60 ± 0.43
Table 2: Few-shot classification accuracies on miniImageNet. denotes that it is implemented by public code. [10, 25, 18] and DPGN are tested in transduction.
Method               Backbone   5way-1shot     5way-5shot
MAML* [8]            ConvNet    51.67 ± 1.81   70.30 ± 1.75
ProtoNet* [36]       ConvNet    53.34 ± 0.89   72.69 ± 0.74
RelationNet* [40]    ConvNet    54.48 ± 0.93   71.32 ± 0.78
TPN [25]             ConvNet    59.91 ± 0.94   73.30 ± 0.75
Edge-label [18]      ConvNet    63.52 ± 0.52   80.24 ± 0.49
DPGN                 ConvNet    69.43 ± 0.49   85.92 ± 0.42
CTM [22]             ResNet18   64.78 ± 0.11   81.05 ± 0.52
DPGN                 ResNet18   70.46 ± 0.52   86.44 ± 0.41
TapNet [45]          ResNet12   63.08 ± 0.15   80.26 ± 0.12
Meta-Transfer [39]   ResNet12   65.62 ± 1.80   80.61 ± 0.90
MetaOptNet [21]      ResNet12   65.81 ± 0.74   81.75 ± 0.53
Shot-Free [30]       ResNet12   66.87 ± 0.43   82.64 ± 0.39
DPGN                 ResNet12   72.45 ± 0.51   87.24 ± 0.39
Table 3: Few-shot classification accuracies on tieredImageNet. denotes that it is implemented by public code. * denotes that it is reported from [21]. [25, 18] and DPGN are tested in transduction.
CUB-200-2011
Method              Backbone   5way-1shot     5way-5shot
ProtoNet* [36]      ConvNet    51.31 ± 0.91   70.77 ± 0.69
MAML* [8]           ConvNet    55.92 ± 0.95   72.09 ± 0.76
MatchingNet* [41]   ConvNet    61.16 ± 0.89   72.86 ± 0.70
RelationNet* [40]   ConvNet    62.45 ± 0.98   76.11 ± 0.69
CloserLook [4]      ConvNet    60.53 ± 0.83   79.34 ± 0.61
DN4 [23]            ConvNet    53.15 ± 0.84   81.90 ± 0.60
DPGN                ConvNet    76.05 ± 0.51   89.08 ± 0.38
FEAT [43]           ResNet12   68.87 ± 0.22   82.90 ± 0.15
DPGN                ResNet12   75.71 ± 0.47   91.48 ± 0.33

CIFAR-FS
Method              Backbone   5way-1shot   5way-5shot
ProtoNet* [36]      ConvNet    55.5 ± 0.7   72.0 ± 0.6
MAML* [8]           ConvNet    58.9 ± 1.9   71.5 ± 1.0
RelationNet* [40]   ConvNet    55.0 ± 1.0   69.3 ± 0.8
R2D2 [3]            ConvNet    65.3 ± 0.2   79.4 ± 0.1
DPGN                ConvNet    76.4 ± 0.5   88.4 ± 0.4
Shot-Free [30]      ResNet12   69.2 ± 0.4   84.7 ± 0.4
MetaOptNet [21]     ResNet12   72.0 ± 0.7   84.2 ± 0.5
DPGN                ResNet12   77.9 ± 0.5   90.2 ± 0.4
Table 4: Few-shot classification accuracies on CUB-200-2011 and CIFAR-FS. * denotes that it is reported from [21] or [4]. DPGN is tested in transduction.
Figure 4: Semi-supervised few-shot learning accuracy in 5way-10shot on miniImageNet. DPGN consistently surpasses TPN and EGNN by a large margin.

Semi-supervised Few-shot Learning

We employ DPGN on semi-supervised few-shot learning. Following [25, 18], we use the same criteria to split miniImageNet dataset into labeled and unlabeled parts with different ratios. For a 20% labeled semi-supervised scenario, we split the support samples with a ratio of 0.2/0.8 for labeled and unlabeled data in each class. In semi-supervised few-shot learning, DPGN uses unlabeled support samples to explicitly construct similarity distributions over all other samples and the distributions work as a connection between queries and labeled support samples, which could propagate label information from labeled samples to queries sufficiently.
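The per-class labeled/unlabeled split described above can be sketched as follows. The function name, the dict-of-lists layout, and the choice to keep at least one labeled sample per class are assumptions made for this example.

```python
import random


def split_support(support_by_class, labeled_ratio, rng):
    """Split each class's support samples into labeled and unlabeled parts.

    support_by_class: dict mapping class id to its support samples
                      (hypothetical layout).
    Returns (labeled, unlabeled): labeled is a list of (sample, class)
    pairs; unlabeled is a list of samples with the label withheld.
    """
    labeled, unlabeled = [], []
    for cls in sorted(support_by_class):
        samples = list(support_by_class[cls])
        rng.shuffle(samples)
        # Assumption: always keep at least one labeled sample per class.
        n_labeled = max(1, round(labeled_ratio * len(samples)))
        labeled += [(x, cls) for x in samples[:n_labeled]]
        unlabeled += samples[n_labeled:]
    return labeled, unlabeled


# 5way-10shot support set with 20% of the labels kept.
toy_support = {c: [f"s_{c}_{i}" for i in range(10)] for c in range(5)}
labeled, unlabeled = split_support(toy_support, 0.2, random.Random(0))
```

In this 5way-10shot example, a 0.2/0.8 ratio leaves two labeled and eight unlabeled support samples in every class.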

Method Transduction 5way-5shot
Reptile [28] No 62.74
GNN [10] No 66.41
Edge-label [18] No 66.85
DPGN No 72.83
MAML [8] BN 63.11
Reptile [28] BN 65.99
RelationNet [40] BN 67.07
MAML [8] Yes 66.19
TPN [25] Yes 69.86
Edge-label [18] Yes 76.37
DPGN Yes 84.62
Table 5: Transductive/non-transductive experiments on miniImageNet. “BN” means information is shared among test examples using batch normalization. denotes that it is implemented by public code released by authors.

In Figure 4, DPGN shows its superiority over existing semi-supervised few-shot methods, and the result demonstrates its effectiveness in exploiting the relations between labeled and unlabeled data when the label ratio decreases. Notably, DPGN surpasses TPN [25] and EGNN [18] by 11% to 16% and 7% to 13% respectively in few-shot average classification accuracy on miniImageNet.

Figure 5: High-way few-shot classification accuracies on miniImageNet.

Transductive Propagation

To validate the effectiveness of the transductive setting in our framework, we conduct transductive and non-transductive experiments on the miniImageNet dataset in the 5way-5shot setting. Table 5 shows that the accuracy of DPGN increases by a large margin in the transductive setting compared with the non-transductive setting. Compared to TPN and EGNN, which consider instance-level features only, DPGN utilizes distribution similarities between query samples and adopts a dual-graph architecture to propagate label information more thoroughly.

High-way classification

Furthermore, the performance of DPGN in high-way few-shot scenarios is evaluated on the miniImageNet dataset, and the results are shown in Figure 5. The results show that DPGN not only exceeds the powerful graph-based methods [25, 18] but also surpasses the state-of-the-art non-graph methods significantly. As the number of ways in a few-shot task increases, the scope for distribution utilization broadens, making it possible for DPGN to collect more abundant distribution-level information for queries.

4.3 Ablation Studies

Figure 6: Effectiveness of through keeping n dimensions in 5way-1shot on miniImageNet.
Figure 7: Generation number in DPGN on miniImageNet, tieredImageNet, CUB-200-2011 and CIFAR-FS.

Impact of Distribution Graph

The distribution graph works as an important component of DPGN by propagating distribution information, so it is necessary to investigate the effectiveness of quantitatively. We design the experiment by limiting the distribution similarities which flow to for performing aggregation in each generation during the inference process. Specifically, we mask out the edge features through keeping a different number of feature dimensions and set the value of rest dimensions to zero, since zero gives no contribution. Figure 6 shows the result for our experiment in 5way-1shot on miniImageNet. It is obvious that test accuracy and the number of feature dimensions kept in have positive correlations and accuracy increment (area in blue) decreases with more feature dimensions. Keeping dimensions from to , DPGN boosts the performance nearly by 10% in absolute value and the result shows that the distribution graph has a great impact on our framework.

Figure 8: The visualization of edge prediction in each generation of DPGN. (a) to (f) denote generations 1 to 6. Darker cells denote higher scores and lighter cells denote lower confidence. The left axis stands for the index of the 5 query images and the bottom axis stands for the 5 support classes.

Generation Numbers

DPGN has a cyclic architecture that includes a point graph and a distribution graph; each graph has its own node-update and edge-update modules. The total number of generations is an important ingredient for DPGN, so we perform experiments to obtain the trend of test accuracy with different generation numbers on miniImageNet, tieredImageNet, CUB-200-2011, and CIFAR-FS. In Figure 7, as the generation number changes from 0 to 1, the test accuracy rises significantly. When the generation number changes from 1 to 10, the test accuracy increases by a small margin and the curve begins to fluctuate in the last several generations. Considering that more generations need more iterations to converge, we choose 6 generations as a trade-off between test accuracy and convergence time. Additionally, to visualize the procedure of the cyclic update, we choose a test scenario where the ground-truth classes of five query images are [1, 2, 3, 4, 5] and visualize the instance-level similarities which are used for the predictions of the five query samples, as shown in Figure 8. The heatmap shows that DPGN refines the instance-level similarity matrix after several generations and makes the right predictions for the five query samples in the final generation. Notably, DPGN not only predicts more accurately but also enlarges the similarity distances between samples in different classes by making instance features more discriminative, which cleans the prediction heatmap.

5 Conclusion

In this paper, we have presented the Distribution Propagation Graph Network for few-shot learning, a dual complete graph network that explicitly combines instance-level and distribution-level relations and is equipped with label propagation and transduction. The point and distribution losses are used to jointly update the parameters of the DPGN with episodic training. Extensive experiments demonstrate that our method outperforms recent state-of-the-art algorithms by 5% to 12% in the supervised task and 7% to 13% in the semi-supervised task on few-shot learning benchmarks. For future work, we aim to focus on high-order message propagation through encoding more complicated information which is linked with task-level relations.

6 Acknowledgement

This research was supported by National Key R&D Program of China (No. 2017YFA0700800).

References