Semi-Supervised Few-Shot Learning with Prototypical Networks

11/29/2017 ∙ by Rinu Boney, et al. ∙ The Curious AI Company

We consider the problem of semi-supervised few-shot classification (when the few labeled samples are accompanied by unlabeled data) and show how to adapt Prototypical Networks to this problem. We first show that using larger and better regularized prototypical networks can improve the classification accuracy. We then show further improvements by making use of unlabeled data.




1 Introduction

Few-shot learning addresses the problem of learning new concepts quickly, which is one of the important properties of human intelligence. In few-shot classification, the task is to adapt a classifier to previously unseen classes from just a few examples (Lake et al., 2015; Koch et al., 2015; Vinyals et al., 2016; Ravi and Larochelle, 2016). This skill is useful in many practical applications since data annotation is laborious and a training set can be rather small. Recent approaches have focused on casting few-shot learning as a meta-learning problem, in which a model is trained on a variety of tasks to adapt quickly using a small number of training examples (Finn et al., 2017a; Munkhdalai and Yu, 2017; Ravi and Larochelle, 2016; Li et al., 2017). Thus, the system learns to adapt to new problems using few labeled samples based on the knowledge transferred from its previous experience with related problems.

In some real-world problems, the tasks to which the learning system needs to adapt may contain both (few) labeled and (many) unlabeled examples, a problem known as semi-supervised learning. For example, a common feature of photo management applications is the automatic organization of images based on limited interactive supervision from the user. The classes relevant to a specific user are likely to be different from the classes in publicly available image datasets such as ImageNet, so there is a need for adaptation. The user can facilitate the adaptation by labeling a few images according to personal preferences. In this application, the learning system also has access to many unlabeled images and it can make use of those to improve the classification accuracy. This is the problem of semi-supervised few-shot classification that we consider.

In this paper, we show how to tackle this problem using a model called Prototypical Networks (PN) (Snell et al., 2017). Using the observation that PN tends to produce clustered data representations, we cast the semi-supervised few-shot problem as a semi-supervised clustering problem and propose several approaches to solve it. We also take inspiration from Cohn et al. (2003) to enable adaptation to new classification tasks using feedback from the user. We argue that this approach can be practical in many real-world applications, as many use cases of semi-supervised few-shot adaptation imply interaction with a user and therefore active learning is often possible.

We use the following formulation of the few-shot classification problem. There is a training set which consists of a large set of classes and we have access to (potentially many) labeled samples from each class in the training set. At test time, the task is to separate $N$ previously unseen classes using a small number $K$ of labeled samples per class and (potentially many) unlabeled samples from the same classes. This is called an $N$-way $K$-shot classification task. We follow the recent literature and use the episodic regime of training and evaluation, as we explain in the following section.

2 Prototypical Networks

The episodic training of PN iterates between the following steps. A subset of $N$ classes is randomly selected to formulate one training task. For each training task, a support set $S = \{(x_i, y_i)\}$ and a query set $Q = \{(x_j, y_j)\}$ are created by sampling examples from the selected classes, where $x_i$ are inputs and $y_i$ are the corresponding labels.

Prototypical Networks compute representations of the inputs using an embedding function $f$ parameterized with $\theta$: $z = f(x; \theta)$. Each class $c$ is represented in the embedding space by a prototype vector $m_c$ which is computed as the mean vector of the embedded inputs for all the examples $S_c$ of the corresponding class $c$:

$$m_c = \frac{1}{|S_c|} \sum_{(x_i, y_i) \in S_c} f(x_i; \theta). \tag{1}$$

The distribution over predicted labels $y$ for a new sample $x$ is computed using softmax over negative distances $d$ to the prototypes in the embedding space:

$$p(y = c \mid x) = \frac{\exp\left(-d(f(x; \theta), m_c)\right)}{\sum_{c'} \exp\left(-d(f(x; \theta), m_{c'})\right)}. \tag{2}$$

Parameters $\theta$ are updated so as to improve the likelihood computed on the query set:

$$L(\theta) = \sum_{(x_j, y_j) \in Q} \log p(y = y_j \mid x_j),$$

which is computed using (2) with the estimated prototypes.
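As a rough illustration of the three steps of this section (class prototypes as mean embeddings, a softmax over negative distances, and the likelihood on the query set), the following NumPy sketch operates on inputs that are assumed to be embedded already. The function names, the choice of squared Euclidean distance and the toy numbers are assumptions made for this example; the embedding network and the gradient step are omitted.

```python
import numpy as np

def class_prototypes(support_emb, support_labels, n_classes):
    """Mean embedding of the support examples of each class."""
    return np.stack([support_emb[support_labels == c].mean(axis=0)
                     for c in range(n_classes)])

def log_softmax_neg_dist(query_emb, prototypes):
    """Log-probabilities from a softmax over negative squared distances."""
    # Pairwise squared Euclidean distances, shape (n_query, n_classes).
    diff = query_emb[:, None, :] - prototypes[None, :, :]
    logits = -(diff ** 2).sum(axis=-1)
    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    return logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))

def query_nll(query_emb, query_labels, prototypes):
    """Average negative log-likelihood of the query labels (to be minimized)."""
    log_p = log_softmax_neg_dist(query_emb, prototypes)
    return -log_p[np.arange(len(query_labels)), query_labels].mean()

# Toy 2-way task with 2-D embeddings (made-up numbers).
support = np.array([[0.0, 0.0], [0.0, 2.0], [4.0, 0.0], [4.0, 2.0]])
support_y = np.array([0, 0, 1, 1])
query = np.array([[0.5, 1.0], [3.5, 1.0]])
query_y = np.array([0, 1])

protos = class_prototypes(support, support_y, n_classes=2)
probs = np.exp(log_softmax_neg_dist(query, protos))
loss = query_nll(query, query_y, protos)
```

Minimizing `loss` with respect to the embedding parameters is equivalent to improving the query-set likelihood; in a real implementation this would be done with an automatic differentiation library.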

3 Adaptation Using Unlabeled Data

3.1 Semi-Supervised Few-Shot Adaptation

In the semi-supervised scenario, we need to adapt to tasks which contain both labeled and unlabeled samples. Classical methods of semi-supervised learning (Chapelle et al., 2006) are based on making certain assumptions such as the smoothness assumption (the label function is smoother in high-density data regions than in low-density regions), the cluster assumption (points in the same cluster are likely to be in the same class), and the manifold assumption (the data lie on a low-dimensional manifold). Many recent algorithms in the deep learning literature are built around the smoothness and the manifold assumptions (see, e.g., Rasmus et al., 2015; Miyato et al., 2015; Laine and Aila, 2016; Tarvainen and Valpola, 2017), which are induced by extra terms in the objective function that depend on unlabeled data.

In this paper, we address semi-supervised classification using the cluster assumption. The motivation for this comes from the observation that PN tends to produce clustered data representations in the embedding space. This is induced by the PN decision rule (2) which uses the distances of a sample to the class means. The clustering approach to semi-supervised learning is often referred to as semi-supervised clustering (Basu et al., 2002).

Our proposed algorithm works as follows. PN is trained with a standard training procedure which involves sampling of tasks, computing the prototypes of each class and updating parameters $\theta$ of the embedding network using stochastic gradient descent, where the gradient is computed using samples of the query set. At test time, the prototypes are first estimated with the labeled data using (1). We then perform the standard $k$-means algorithm (Lloyd, 1982) on the embeddings of both labeled and unlabeled data, initializing the cluster means with the prototypes computed from the labeled data. This corresponds to the seeding approach of $k$-means proposed by Basu et al. (2002). The algorithm typically converges in just a few iterations (see Table 5). We have also tried the constrained $k$-means approach in which the cluster membership of the labeled examples is never changed. Both approaches yielded similar results and therefore we present only the results of the seeding approach, which was slightly better.
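The test-time procedure described above (prototypes from the labeled samples, then $k$-means on all embeddings seeded with those prototypes) can be sketched as follows. This is a minimal illustration with hard assignments and squared Euclidean distance; the function name, the iteration cap and the toy embeddings are assumptions for the example, not details taken from the paper.

```python
import numpy as np

def seeded_kmeans(embeddings, init_means, n_iter=10):
    """Lloyd's k-means with the cluster means initialized from prototypes."""
    means = np.array(init_means, dtype=float)  # copy, so the seeds are not modified
    assign = None
    for _ in range(n_iter):
        # Assignment step: index of the nearest mean for every point.
        dists = ((embeddings[:, None, :] - means[None, :, :]) ** 2).sum(axis=-1)
        new_assign = dists.argmin(axis=1)
        if assign is not None and np.array_equal(new_assign, assign):
            break  # assignments stopped changing: converged
        assign = new_assign
        # Update step: recompute each mean; an empty cluster keeps its old mean.
        for c in range(len(means)):
            members = embeddings[assign == c]
            if len(members) > 0:
                means[c] = members.mean(axis=0)
    return means, assign

# One labeled sample per class (both slightly off-center), plus unlabeled samples.
labeled = np.array([[1.5, 0.0], [4.0, 0.0]])
unlabeled = np.array([[-0.5, 0.0], [0.0, 0.5], [0.5, -0.5], [0.0, 0.0],
                      [5.0, 0.5], [4.5, 0.0], [5.5, 0.0], [5.0, -0.5]])
all_emb = np.concatenate([labeled, unlabeled])

prototypes = labeled.copy()  # with one shot per class, the prototype is the sample
means, assign = seeded_kmeans(all_emb, prototypes)
```

In this toy task the refined means move from the off-center labeled samples toward the centers of the two groups of points, which is the effect the unlabeled data is meant to have.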

The proposed algorithm is related to the one concurrently developed by Ren et al. (2018) who do semi-supervised few-shot classification using the constrained $k$-means clustering. The main difference is that they perform clustering also at training time and therefore they use soft cluster assignments to keep the computational graph differentiable. They do only one iteration of $k$-means, as doing more iterations does not improve the performance. We obtained similar results using soft cluster assignment and found experimentally that $k$-means with hard clustering behaves more robustly (see results in Table 5). Furthermore, Ren et al. (2018) consider the scenario in which the unlabeled support set may contain samples from irrelevant classes.

Figure 1: Left: Example task from the sine data set. The black dots correspond to samples of one class and the white dots correspond to samples from the other class. The optimal decision boundary is shown with the dashed blue line. Middle: Example of adaptation to a test task from 10 labeled samples. Right: Example of adaptation to a test task from the same 10 labeled samples and 100 unlabeled samples. The blue dots correspond to unlabeled samples.

3.2 Unsupervised Adaptation

In many practical applications, the classification tasks to which a classifier needs to adapt share the same output space. As an example, consider a material recognition system deployed in multiple factories. The same categories of materials need to be recognized at different sites; however, factories may have slightly different lighting conditions or other factors affecting the recognition process, motivating the need for adaptation of each recognition system. In these scenarios, since the classes do not change across tasks, the average class prototypes can be computed at the end of training by averaging the prototypes of the corresponding class from all the training tasks. At test time, we cluster the samples using $k$-means and then assign each cluster to the class of the closest prototype.
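The two steps of unsupervised adaptation (averaging prototypes over training tasks, then naming each test-time cluster after the closest average prototype) can be sketched as below. The array layout, the function name and the numbers are assumptions for illustration; the cluster means are taken as given, since the clustering itself is ordinary $k$-means.

```python
import numpy as np

def label_clusters(cluster_means, avg_prototypes):
    """Assign each cluster the class of its closest average prototype."""
    diff = cluster_means[:, None, :] - avg_prototypes[None, :, :]
    return (diff ** 2).sum(axis=-1).argmin(axis=1)

# Prototypes from two training tasks, shape (n_tasks, n_classes, dim).
task_prototypes = np.array([[[0.0, 0.0], [4.0, 0.0]],
                            [[0.0, 2.0], [4.0, 2.0]]])
avg_prototypes = task_prototypes.mean(axis=0)  # one prototype per class

# Cluster means found by k-means on unlabeled test-time samples (made up).
cluster_means = np.array([[3.5, 1.2], [0.4, 0.8]])
cluster_classes = label_clusters(cluster_means, avg_prototypes)
```

Note that nothing in this rule prevents two clusters from being mapped to the same class.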

3.3 Experiments with Synthetic Data

To test the proposed algorithm, we created a synthetic data set which is fast to experiment with. The dataset consists of a set of two-dimensional classification tasks with two classes, in which the optimal decision boundary is a sine wave (see Fig. 1). The amplitude $A$ and the phase $\phi$ of the optimal decision boundary vary across tasks, each within a fixed range. The first dimension $x_1$ of the data samples is drawn uniformly from a fixed interval and the second dimension is computed as

$$x_2 = A \sin(x_1 + \phi) + \epsilon,$$

where $\epsilon$ is a noise term with the Laplace distribution with a class-dependent mean and the scale parameter 0.5. We sampled 100 tasks for training and 1000 tasks for testing.
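A generator for tasks of this kind might look like the sketch below. The amplitude, phase and input ranges and the class-dependent noise means are not available in this text, so they are plain function arguments here, and the values in the usage lines are placeholders rather than the paper's settings. The sign convention of the phase and the symmetric shift of the noise mean (plus for one class, minus for the other) are likewise assumptions.

```python
import numpy as np

def sample_sine_task(rng, n_samples, amp_range, phase_range, x_range,
                     class_shift, noise_scale=0.5):
    """Sample one two-class task whose optimal boundary is a sine wave."""
    amplitude = rng.uniform(*amp_range)
    phase = rng.uniform(*phase_range)
    x1 = rng.uniform(*x_range, size=n_samples)
    labels = rng.integers(0, 2, size=n_samples)
    # Assumed: the noise mean is +class_shift for class 1, -class_shift for class 0.
    noise_mean = np.where(labels == 1, class_shift, -class_shift)
    noise = rng.laplace(loc=noise_mean, scale=noise_scale)
    x2 = amplitude * np.sin(x1 + phase) + noise
    return np.stack([x1, x2], axis=1), labels, (amplitude, phase)

# Placeholder ranges, chosen only to make the example run.
rng = np.random.default_rng(0)
points, labels, (amp, phase) = sample_sine_task(
    rng, n_samples=50, amp_range=(0.5, 2.0), phase_range=(0.0, 3.0),
    x_range=(-4.0, 4.0), class_shift=1.0)
```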

In this experiment, we use a fully connected network with two hidden layers of size 40 with ReLU nonlinearity as the embedding network. Examples of the decision boundaries produced by PN on test tasks are shown in Fig. 1. Note that even for a small number of training examples the PN decision boundaries resemble the sine wave, thus the knowledge is transferred between tasks. The classification errors of semi-supervised adaptation on test tasks depending on the number of unlabeled samples are shown in Table 1. One can see from the results that the PN accuracy improves with the use of unlabeled data. However, the improvement plateaus and the results do not improve much after adding more than 100 data points.

We also experimented with the proposed unsupervised adaptation method on the synthetic dataset. The classification errors as a function of the number of unlabeled samples are shown in Table 2.

Extra samples   Labeled   Unlabeled
0               6.38      6.38
10              4.97      5.54
100             2.73      4.70
1000            1.68      4.57
Table 1: Semi-supervised adaptation on sine data set: Classification errors obtained with 10 labeled samples and extra samples (either labeled or unlabeled).

Samples   Supervised   Unsupervised
10        6.38         7.71
100       2.98         6.40
1000      1.99         6.39
Table 2: Fully-supervised and unsupervised adaptation on sine data set: Classification errors obtained with labeled or unlabeled samples.

3.4 Experiments with miniImagenet

We tested the proposed method in the semi-supervised scenario on the miniImagenet recognition task proposed by Vinyals et al. (2016). The dataset consists of downsampled 84x84 images from 64 training classes, 12 validation classes, and 24 test classes from ImageNet. We use the same split as Ravi and Larochelle (2016) and follow the experimental setup of $N$-way $K$-shot classification, similarly to Vinyals et al. (2016); that is, every task at test time contains $N$ classes with $K$ labeled examples from each class. For the semi-supervised case, we also assume the existence of $M$ unlabeled samples per class at test time. Following previous works, we use 15 test samples per class for each task during training and for evaluation. We evaluate the model over 2400 tasks which involve the 24 classes reserved for testing.
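Accuracies averaged over many test tasks are usually reported together with a 95% confidence interval. One common way to obtain such an interval is the normal approximation sketched below; whether the paper used exactly this estimator is an assumption, and the per-task accuracies in the usage lines are made up.

```python
import math
import statistics

def mean_with_ci95(task_accuracies):
    """Mean accuracy and half-width of a normal-approximation 95% interval."""
    n = len(task_accuracies)
    mean = statistics.fmean(task_accuracies)
    # Standard error of the mean from the sample standard deviation.
    half_width = 1.96 * statistics.stdev(task_accuracies) / math.sqrt(n)
    return mean, half_width

# Made-up accuracies of 100 test tasks.
accuracies = [0.4, 0.6] * 50
mean_acc, ci_half_width = mean_with_ci95(accuracies)
```

The half-width shrinks with the square root of the number of tasks, which is why evaluation over a few thousand tasks gives tight intervals.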

One challenge with the miniImagenet data set is that it requires rather complex features but contains a relatively small amount of data. Therefore, preventing overfitting becomes an important issue. Most previous works (Vinyals et al., 2016; Ravi and Larochelle, 2016; Finn et al., 2017b; Snell et al., 2017) used conventional convolutional networks with a small number of layers to prevent overfitting. For comparability, we use the same architecture which consists of four blocks. Each block consists of a $3 \times 3$ convolutional layer with 64 channels followed by batch normalization, ReLU non-linearity and a $2 \times 2$ max-pooling layer, which results in an embedding space of dimensionality 1600. In Prototypical Networks (Snell et al., 2017), the model was fine-tuned by having more classes at training time than at test time (e.g., 30-way training and 5-way testing) and a learning rate decay schedule. We use a constant learning rate and 5-way training for simplicity. This simple network may not be expressive enough to capture relevant complex features, so we also consider a Wide Residual Network (Zagoruyko and Komodakis, 2016) as the embedding network. We use a network of depth 16 and a widening factor of 6. We also use pooling with a stride of 4 at the end to obtain embeddings of dimensionality 384. The network is regularized with dropout. We use the Adam optimizer (Kingma and Ba, 2014) with a learning rate of 0.01 for training the ResNet and 0.001 for training the four-block architecture. We perform early stopping to prevent overfitting. (Residual Networks are typically trained using stochastic gradient descent with momentum, and we expect better results by doing the same and fine-tuning the hyperparameters.)
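The two embedding dimensionalities quoted above can be checked with a few lines of arithmetic. The sketch assumes that the convolutions preserve the spatial size, that each pooling layer halves it with rounding down, and that the last group of a Wide Residual Network has a base width of 64 channels multiplied by the widening factor; the helper name is made up for this example.

```python
def spatial_size_after_blocks(size, n_blocks=4, pool=2):
    """Spatial side length after n_blocks of (size-preserving conv + pooling)."""
    for _ in range(n_blocks):
        size = size // pool  # pooling with rounding down
    return size

side = spatial_size_after_blocks(84)  # 84 -> 42 -> 21 -> 10 -> 5
four_block_dim = side * side * 64     # flattened feature map with 64 channels
wrn_channels = 64 * 6                 # last-group base width times widening factor
```

Under these assumptions the four-block network yields 5 x 5 x 64 features and the Wide ResNet yields one value per output channel after the final pooling.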

We train the ResNet model with a fixed number of classes in each task, chosen separately for 1-shot and for 5-shot classification, similarly to the original PN paper. For the ResNet model, we tried varying the number of classes during training and observed that it had a much smaller impact on the results compared to the observations in Snell et al. (2017). This suggests that tuning this regularization parameter is less important for this architecture.

The results presented in Table 3 show that our approach scales to both feature extractor architectures. One can observe that using the Wide ResNets to learn the embedding space yields noticeable improvements in the classification accuracy compared to the baseline methods. The improvements are more significant for the 20-way classification. Table 3 presents the results of the PN tested in the semi-supervised scenario: The prototypes are re-estimated at test time using both labeled samples and the inputs from the query set.

In Table 4, we show how the number of unlabeled examples at test time affects the classification accuracy of the trained PN. The results indicate that more unlabeled samples yield better performance; however, the improvement plateaus very quickly as the number of unlabeled samples increases. This agrees with the results obtained on the synthetic data reported in Section 3.3. Notably, the classification performance of the four-block architecture scales well with the number of unlabeled samples, closely matching the performance of the ResNet in the case of 120 unlabeled samples per class. The evolution of the classification accuracy with the number of $k$-means iterations is shown in Table 5.

Model   5-way 1-shot   5-way 5-shot   20-way 1-shot   20-way 5-shot
fine-tuning baseline
nearest-neighbor baseline
Meta-LSTM (Ravi and Larochelle, 2016)
Matching nets (Vinyals et al., 2016)
MAML (Finn et al., 2017b)
ARC (Shyam et al., 2017)
Meta Networks (Munkhdalai and Yu, 2017)
PN (Snell et al., 2017)
Meta-SGD (Li et al., 2017)
PN (ours)
PN, re-estimate (ours)
Resnet PN (ours)
Resnet PN, re-estimate (ours)
Table 3: Average classification accuracy (with 95% confidence intervals) on miniImagenet. The comparison numbers for the 20-way testing are from Li et al. (2017).

Model   $M$   5-way 1-shot   5-way 5-shot
PN (ours)   15
Resnet PN   15
Table 4: Average classification accuracy (with 95% confidence intervals) on miniImagenet for the semi-supervised scenario for different numbers of unlabeled samples per class ($M$) available at test time.

Model   Iterations   Hard $k$-means   Soft $k$-means
PN (ours)   0
Resnet PN   0
Table 5: Average classification accuracy (with 95% confidence intervals) on miniImagenet for the semi-supervised scenario as a function of the number of $k$-means iterations. Each task consists of 1 labeled sample and 15 unlabeled samples.

4 Active Few-Shot Adaptation

There are two sources of errors which the semi-supervised adaptation algorithm proposed in Section 3.1 can accumulate: 1) errors due to incorrect clustering of data, 2) errors due to incorrect labeling of the clusters. The second type of errors can occur when the few labeled examples are outliers which end up closer to the prototype of another class in the embedding space. In this paper, we advocate that the most practical way to correct the second type of errors can be through user feedback, since in many applications of the semi-supervised few-shot adaptation, interaction with the user is possible. This idea is inspired by the work of Cohn et al. (2003) who introduced a clustering approach that allows a user to iteratively provide feedback to a clustering algorithm.

Consider the previously introduced example of few-shot learning in photo management applications. Although it is possible to ask the user to label a few photographs and use those labels to classify the rest of the pictures, it is extremely difficult and tiresome for the user to scroll through all the photos and decide which samples should be labeled. Instead, using the observation that “It is easier to criticize than to create” (Cohn et al., 2003), one can initially cluster the photos and then request the user to label certain photos (or provide other types of feedback) so that the data are properly clustered and labeled. The user can provide feedback in various forms and therefore can effectively introduce various constraints that can further guide the clustering process. For example, a user can assign the whole cluster to a particular class, assign a sample to a particular cluster, mark that a particular sample does not belong to the assigned cluster, or split and combine clusters. These constraints can be easily incorporated into basic clustering algorithms such as $k$-means. For example, Wagstaff et al. (2001) introduced constraints between samples in the data set such as must-link (two samples have to be in the same cluster) and cannot-link (two samples have to be in different clusters), and the clustering algorithm finds a solution that satisfies all the constraints.
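To make the must-link and cannot-link constraints concrete, the sketch below checks whether placing a sample into a cluster would violate any constraint, which is the kind of test a constrained $k$-means assignment step needs. The data layout (a dictionary from sample index to cluster index, constraints as index pairs) and the function names are assumptions for this illustration, not the original algorithm's interface.

```python
def partners(point, pairs):
    """Samples that are paired with `point` in a list of (a, b) constraints."""
    return [b if a == point else a for a, b in pairs if point in (a, b)]

def violates_constraints(point, cluster, assignment, must_link, cannot_link):
    """True if putting `point` into `cluster` breaks a pairwise constraint."""
    # Must-link: an already assigned partner has to be in the same cluster.
    for other in partners(point, must_link):
        if other in assignment and assignment[other] != cluster:
            return True
    # Cannot-link: an already assigned partner must not be in this cluster.
    for other in partners(point, cannot_link):
        if assignment.get(other) == cluster:
            return True
    return False

# Samples 0 and 1 are already assigned; samples 2 and 3 are not.
assignment = {0: 0, 1: 1}
must_link = [(0, 2)]     # sample 2 has to join the cluster of sample 0
cannot_link = [(1, 3)]   # sample 3 must avoid the cluster of sample 1

blocked_2 = violates_constraints(2, 1, assignment, must_link, cannot_link)
allowed_2 = not violates_constraints(2, 0, assignment, must_link, cannot_link)
blocked_3 = violates_constraints(3, 1, assignment, must_link, cannot_link)
```

In an assignment step, each sample would be placed into the nearest cluster for which this check returns false.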

Even outside the context of few-shot learning, this active learning approach can be used to adapt a pre-trained classifier. Assume that we have a classifier that clusters the classes of a particular classification task such as ImageNet. Then, during test time it is possible to interactively split clusters to make finer-grained classifications or to assign multiple clusters to a super-cluster (to make hierarchical predictions).

In this paper, we assume that the user can provide feedback only in the form of labeling a particular sample or labeling the whole cluster. We propose to use PN as a feature extractor, cluster the samples in the embedding space using $k$-means and then label the clusters by requesting one labeled example for each cluster from the user. For each cluster $c$, we choose sample $z_c^*$ to be labeled by the user by maximizing an acquisition function $a$:

$$z_c^* = \arg\max_{z \in C_c} a(z, c),$$

where $C_c$ is the set of embedded inputs belonging to cluster $c$. We explore a few acquisition functions:

  • Random: Sample a data point uniformly at random from each cluster. This is a baseline approach.

  • Nearest: Select the data point which is closest to the cluster center:

    $$a(z, c) = -d(z, m_c),$$

    where $d$ is the distance in the embedding space and $m_c$ is the mean (cluster center) of cluster $c$.

  • Entropy: Select the sample with the least entropy:

    $$a(z, c) = \sum_{c'} p(y = c' \mid z) \log p(y = c' \mid z).$$

    Thus, we select a sample with the least uncertainty that it belongs to a certain cluster.

  • Margin: Select a sample with the largest margin between the most likely and second most likely labels:

    $$a(z, c) = p(y = c_1 \mid z) - p(y = c_2 \mid z),$$

    where $c_1$ and $c_2$ are the most likely and the second most likely clusters of embedded input $z$ respectively. This quantity was proposed as a measure of uncertainty by Scheffer et al. (2001).

We also simulate the case in which the user can label the whole cluster, which is possible in some applications. This approach directly measures the clustering accuracy and we call it “oracle”.

  • Oracle: We label each cluster based on the distance of the cluster mean to the prototypes computed from the true labels of all the samples.
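The Nearest, Entropy and Margin strategies listed above can be sketched in NumPy as follows. The sketch assumes that cluster probabilities come from a softmax over negative squared Euclidean distances to the cluster means; that choice, the function names and the toy embeddings are assumptions for this example. Every strategy is written as a score to be maximized within a cluster.

```python
import numpy as np

def cluster_probs(emb, means):
    """Softmax over negative squared distances to the cluster means."""
    d = ((emb[:, None, :] - means[None, :, :]) ** 2).sum(axis=-1)
    logits = -d - (-d).max(axis=1, keepdims=True)  # shift for stability
    p = np.exp(logits)
    return p / p.sum(axis=1, keepdims=True)

def acquisition_scores(emb, means, cluster, kind):
    """Score the given samples; the highest-scoring one is shown to the user."""
    if kind == "nearest":
        return -((emb - means[cluster]) ** 2).sum(axis=1)
    p = cluster_probs(emb, means)
    if kind == "entropy":
        return (p * np.log(p + 1e-12)).sum(axis=1)  # negative entropy
    if kind == "margin":
        top2 = np.sort(p, axis=1)[:, -2:]           # two largest probabilities
        return top2[:, 1] - top2[:, 0]
    raise ValueError(f"unknown acquisition function: {kind}")

def select_query(emb, means, assign, cluster, kind):
    """Index of the sample in `cluster` that maximizes the acquisition score."""
    idx = np.flatnonzero(assign == cluster)
    return idx[acquisition_scores(emb[idx], means, cluster, kind).argmax()]

# Toy embeddings and a clustering into two clusters (made-up numbers).
means = np.array([[0.0, 0.0], [4.0, 0.0]])
emb = np.array([[0.1, 0.0], [1.8, 0.0], [-1.0, 0.0], [4.0, 0.2], [2.4, 0.0]])
assign = np.array([0, 0, 0, 1, 1])

picks = {kind: int(select_query(emb, means, assign, 0, kind))
         for kind in ("nearest", "entropy", "margin")}
```

In this toy example Nearest picks the sample closest to the center of cluster 0, whereas Entropy and Margin pick the sample farthest from the other cluster, because that sample has the most confident cluster membership.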

4.1 Experiments with miniImagenet

In the experiments with miniImagenet we use a PN trained in the episodic mode as the feature extractor. We simulate active learning on test tasks by first doing $k$-means clustering in the PN embedding space and then requesting one labeled example for each cluster using the acquisition functions described earlier. Note that multiple clusters can be assigned to the same class if the requested labels guide it that way. This is the largest source of error (see also Fig. 2). Table 6 presents the classification performance of each strategy for test tasks with one labeled sample and a varying number of unlabeled samples. There, we also present the accuracy of the oracle clustering. Overall, the margin approach worked best in our experiments. The 1-shot classification accuracy with 120 unlabeled samples per class even surpassed the 5-shot accuracy of some well-recognized previous methods. Similar to the semi-supervised scenario, the four-block architecture scales well with the number of unlabeled samples, closely matching the performance of the ResNet in the case of 120 unlabeled samples per class and even outperforming it when using the margin strategy.

Model   $M$   Random   Nearest   Entropy   Margin   Oracle
PN (ours)   15
Resnet PN   15
Table 6: Average 1-shot classification accuracy (with 95% confidence intervals) on miniImagenet for the active learning scenario for different numbers of unlabeled samples per class ($M$) available at test time.

Figure 2: Illustration of 3-way 1-shot miniImagenet tasks (best viewed in color). The two-dimensional visualizations are obtained by projecting the 384-dimensional features onto the principal subspace of the prototypes computed from the support set. The support set is represented using triangle markers, the query set using circle markers and the prototypes using star markers. The colors represent the true labels in the first column and the labels produced with different adaptation strategies in columns 2–5. Each row corresponds to one 3-way 1-shot task from miniImagenet. True Labels: The true labels. Optimal: Predictions based on the prototype computed from the true labels of all the samples. Supervised: Predictions based on the prototypes computed from the support set. Semi-supervised: Predictions based on the prototypes (or cluster means) computed after $k$-means clustering seeded by the prototypes of the support set. Active: Predictions based on the prototypes (or cluster means) computed by $k$-means and labeling the clusters using the Nearest approach. (a)–(b): Tasks are reasonably well clustered and supervised adaptation performs quite well. The labeling is further slightly improved by semi-supervised and active adaptation. (c): The support sample of the red class is an outlier and using it as a prototype leads to misclassifications. Even seeding the clustering with these prototypes (semi-supervised classification) leads to incorrect clustering. However, the active adaptation is able to find a reasonable solution. (d): A failure case of the Nearest approach of the active adaptation where the samples are properly clustered but incorrectly labeled. (e): Samples are not properly clustered in the feature space. The supervised and semi-supervised approaches fail badly, while the active approach produces somewhat usable results.
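A projection of this kind can be computed with a singular value decomposition of the centered prototypes: three prototypes span a plane, and the two leading right singular vectors give an orthonormal basis of it. The sketch below is one plausible way to produce such a visualization, with random stand-in features; it is not claimed to be the code behind the figure.

```python
import numpy as np

def project_to_prototype_plane(features, prototypes):
    """Project features onto the 2-D principal subspace of the prototypes."""
    center = prototypes.mean(axis=0)
    # Rows of vt are orthonormal directions; the first two span the plane
    # through the three (centered) prototypes.
    _, _, vt = np.linalg.svd(prototypes - center, full_matrices=False)
    basis = vt[:2]
    return (features - center) @ basis.T

# Random stand-ins for 384-dimensional prototypes and query features.
rng = np.random.default_rng(0)
prototypes = rng.normal(size=(3, 384))
features = rng.normal(size=(10, 384))

features_2d = project_to_prototype_plane(features, prototypes)
prototypes_2d = project_to_prototype_plane(prototypes, prototypes)
```

Distances between the prototypes themselves are preserved by this projection, while distances involving other samples can only shrink.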

5 Discussion

In this paper, we extended Prototypical Networks to adapt to new classification tasks in the semi-supervised few-shot learning scenario, in which a few labeled examples are accompanied by many unlabeled examples from the same classes. We proposed to use the clustering approach to semi-supervised classification, in which the clustering process is guided by the labeled examples. This differs from recent deep learning papers (Rasmus et al., 2015; Miyato et al., 2015; Laine and Aila, 2016; Tarvainen and Valpola, 2017) which constrain the classifier using unlabeled data. We also advocated that in many real-world applications it can be possible to request the few labeled examples from the user, which can yield better performance.

The proposed solution of semi-supervised few-shot adaptation is based on doing $k$-means clustering in the embedding space found by Prototypical Networks. These two methods make a good fit because they make similar assumptions about the data distribution: In Prototypical Networks, the distribution of each class is represented by its mean and the variances of class distributions are assumed equal. The same assumptions are made by $k$-means.
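The shared assumption can be made explicit with a short calculation. Assuming that the distance is the squared Euclidean distance (an assumption of this note), the PN decision rule is the class posterior of a mixture of $N$ Gaussians with equal priors and a common isotropic covariance $\sigma^2 I$ with $\sigma^2 = 1/2$, because the normalization constants of the Gaussians cancel:

$$p(y = c \mid z) = \frac{\exp\left(-\|z - m_c\|^2\right)}{\sum_{c'} \exp\left(-\|z - m_{c'}\|^2\right)} = \frac{\frac{1}{N}\,\mathcal{N}\left(z;\, m_c,\, \tfrac{1}{2} I\right)}{\sum_{c'} \frac{1}{N}\,\mathcal{N}\left(z;\, m_{c'},\, \tfrac{1}{2} I\right)}.$$

Hard $k$-means assigns each point to $\arg\min_c \|z - m_c\|^2$, which is the most probable class under the same model, and then re-estimates each mean from the assigned points.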


The fundamental bottleneck of the proposed approach in improving the classification performance is the ability of the feature extractor to cluster unseen data. Although we used an embedding network trained using Prototypical Networks, the adaptation mechanisms proposed in this paper can be performed using other feature extractors as well. A feature extractor explicitly trained to cluster data can further improve the few-shot classification performance and this is an area of active research (Song et al., 2016, 2017; Law et al., 2017). Building feature extractors that allow better generalization is largely an unsolved problem and it requires further exploration (see, e.g., Sabour et al., 2017; Hinton et al., 2018).


We would like to thank our colleagues from The Curious AI Company for fruitful discussions.


Appendix A Sensitivity to the Number of Shots

Snell et al. (2017) showed that the classification performance of Prototypical Networks is sensitive to the number of classes per task during training and that it was necessary to match the number of samples per class during training and testing. We believe that using a larger number of classes ($N$-way) works as a regularizer. However, varying $N$ effectively changes the batch size and therefore the learning rate needs to be adapted to $N$. By tuning the learning rate for different $N$, we observed a smaller effect of this regularization on the adaptation accuracy.

The dependency on the number of samples per class $K$ implies that the embedding learned by the network does not generalize well to a different number of samples during test time. This is inconvenient especially if the number of labeled examples is unknown in advance and can grow, for example, as a result of interaction with the user. To address this problem, we propose to use a varying number of samples per class during training for better generalization to the number of shots during test time. The results in Table 7 illustrate that this strategy is effective and reduces the sensitivity to $K$ during training.
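Training with a varying number of samples per class only changes how episodes are sampled. The sketch below draws the number of shots anew for every episode; the sampling range, the function name and the toy data set are hypothetical, since the range used in the experiments is not available in this text.

```python
import random

def sample_episode(examples_by_class, n_way, shot_range, n_query, rng):
    """Sample one task; the number of shots is redrawn for every episode."""
    k_shot = rng.randint(*shot_range)  # both bounds are inclusive
    classes = rng.sample(sorted(examples_by_class), n_way)
    support, query = [], []
    for label, cls in enumerate(classes):
        picked = rng.sample(examples_by_class[cls], k_shot + n_query)
        support += [(x, label) for x in picked[:k_shot]]
        query += [(x, label) for x in picked[k_shot:]]
    return support, query, k_shot

# Toy data set: 8 classes with 30 example ids each (hypothetical).
data = {c: list(range(c * 100, c * 100 + 30)) for c in range(8)}
rng = random.Random(0)
support, query, k_shot = sample_episode(
    data, n_way=5, shot_range=(1, 5), n_query=15, rng=rng)
```

Support and query examples of a class are drawn without replacement from the same pool, so they never overlap within an episode.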

Model   Training shot   Testing 1-shot   Testing 5-shot
PN (ours)   1
Resnet PN   1
Table 7: Results of fully supervised adaptation using PN as a function of the number of samples per class during training. One training setting varies the number of samples per class within a range.