Few-shot learning addresses the problem of learning new concepts quickly, which is one of the important properties of human intelligence. In few-shot classification, the task is to adapt a classifier to previously unseen classes from just a few examples (Lake et al., 2015; Koch et al., 2015; Vinyals et al., 2016; Ravi and Larochelle, 2016). This skill is useful in many practical applications since data annotation is laborious and a training set can be rather small. Recent approaches have focused on casting few-shot learning as a meta-learning problem, in which a model is trained on a variety of tasks to adapt quickly using a small number of training examples (Finn et al., 2017a; Munkhdalai and Yu, 2017; Ravi and Larochelle, 2016; Li et al., 2017). Thus, the system learns to adapt to new problems using few labeled samples based on the knowledge transferred from its previous experience with related problems.
In some real-world problems, the tasks to which the learning system needs to adapt contain both (few) labeled and (many) unlabeled examples, a setting known as semi-supervised learning. For example, a common feature of photo management applications is the automatic organization of images based on limited interactive supervision from the user. The classes relevant to a specific user are likely to differ from the classes in publicly available image datasets such as ImageNet, so there is a need for adaptation. The user can facilitate the adaptation by labeling a few images according to personal preferences. In this application, the learning system also has access to many unlabeled images and can make use of those to improve the classification accuracy. This is the problem of semi-supervised few-shot classification that we consider.
In this paper, we show how to tackle this problem using a model called Prototypical Networks (PN) (Snell et al., 2017). Using the observation that PN tends to produce clustered data representations, we cast the semi-supervised few-shot problem as a semi-supervised clustering problem and propose several approaches to solve it. We also take inspiration from Cohn et al. (2003) to enable adaptation to new classification tasks using feedback from the user. We argue that this approach can be practical in many real-world applications: many use cases of semi-supervised few-shot adaptation imply interaction with a user, and therefore active learning is often possible.
We use the following formulation of the few-shot classification problem. There is a training set which consists of a large number of classes, and we have access to (potentially many) labeled samples from each class in the training set. At test time, the task is to separate N previously unseen classes using a small number K of labeled samples and (potentially many) unlabeled samples from the same classes. This is called an N-way K-shot classification task. We follow the recent literature and use the episodic regime of training and evaluation, as we explain in the following section.
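To make the episodic regime concrete, the following sketch builds one N-way K-shot episode from a pool of labeled examples. This is our illustrative code, not the original implementation; the name `sample_episode` and the dictionary-based pool are assumptions for the sake of the example.

```python
import numpy as np

def sample_episode(pool, n_way, k_shot, n_query, rng):
    """Sample one N-way K-shot episode from a labeled pool.

    `pool` maps class id -> array of examples; `n_way` classes are drawn,
    and from each we take `k_shot` support and `n_query` query examples.
    """
    classes = rng.choice(sorted(pool), size=n_way, replace=False)
    support, query = [], []
    # Episode-local labels 0..n_way-1, as in episodic training.
    for label, c in enumerate(classes):
        idx = rng.permutation(len(pool[c]))
        for i in idx[:k_shot]:
            support.append((pool[c][i], label))
        for i in idx[k_shot:k_shot + n_query]:
            query.append((pool[c][i], label))
    return support, query
```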
2 Prototypical Networks
The episodic training of PN iterates between the following steps. A subset of N classes is randomly selected to form one training task. For each training task, a support set S = {(x_i, y_i)} and a query set Q = {(x_j, y_j)} are created by sampling examples from the selected classes, where x are inputs and y are the corresponding labels.
Prototypical Networks compute representations of the inputs using an embedding function f_φ parameterized by φ: z = f_φ(x). Each class k is represented in the embedding space by a prototype vector c_k, computed as the mean vector of the embedded inputs for all the examples of the corresponding class k:

    c_k = (1 / |S_k|) Σ_{(x_i, y_i) ∈ S_k} f_φ(x_i),    (1)

where S_k denotes the set of support examples of class k.
The distribution over predicted labels for a new sample x is computed using a softmax over negative distances to the prototypes in the embedding space:

    p(y = k | x) = exp(−d(f_φ(x), c_k)) / Σ_{k'} exp(−d(f_φ(x), c_{k'})),    (2)

where d is a distance function (the squared Euclidean distance in Snell et al., 2017).
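Equations (1) and (2) can be sketched in a few lines of NumPy. This is illustrative code, not the original implementation; `prototypes` and `predict_proba` are our names, and we use the squared Euclidean distance as in Snell et al. (2017).

```python
import numpy as np

def prototypes(z_support, y_support, n_way):
    # c_k: mean embedding of the support examples of class k  (Eq. 1)
    return np.stack([z_support[y_support == k].mean(axis=0)
                     for k in range(n_way)])

def predict_proba(z_query, protos):
    # Softmax over negative squared Euclidean distances  (Eq. 2)
    d2 = ((z_query[:, None, :] - protos[None, :, :]) ** 2).sum(-1)
    logits = -d2
    logits -= logits.max(axis=1, keepdims=True)  # numerical stability
    e = np.exp(logits)
    return e / e.sum(axis=1, keepdims=True)
```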
3 Adaptation Using Unlabeled Data
3.1 Semi-Supervised Few-Shot Adaptation
In the semi-supervised scenario, we need to adapt to tasks which contain both labeled and unlabeled samples. Classical methods of semi-supervised learning (Chapelle et al., 2006) are based on assumptions such as the smoothness assumption (the label function is smoother in high-density data regions than in low-density regions), the cluster assumption (points in the same cluster are likely to belong to the same class), and the manifold assumption (the data lie on a low-dimensional manifold). Many recent algorithms in the deep learning literature build on the smoothness and manifold assumptions (see, e.g., Rasmus et al., 2015; Miyato et al., 2015; Laine and Aila, 2016; Tarvainen and Valpola, 2017), which are induced by extra terms in the objective function that depend on unlabeled data.
In this paper, we address semi-supervised classification using the cluster assumption. The motivation comes from the observation that PN tends to produce clustered data representations in the embedding space. This is induced by the PN decision rule (2), which uses the distances of a sample to the class means. The clustering approach to semi-supervised learning is often referred to as semi-supervised clustering (Basu et al., 2002).
Our proposed algorithm works as follows. PN is trained with the standard training procedure, which involves sampling tasks, computing the prototypes of each class and updating the parameters φ of the embedding network using stochastic gradient descent, where the gradient is computed on the query set. At test time, the prototypes are first estimated from the labeled data using (1). We then run the standard k-means algorithm (Lloyd, 1982) on the embeddings of both labeled and unlabeled data, initializing the cluster means with the prototypes computed from the labeled data. This corresponds to the seeding approach to k-means proposed by Basu et al. (2002). The algorithm typically converges in just a few iterations (see Table 5). We have also tried the constrained k-means approach, in which the cluster membership of the labeled examples is never changed. Both approaches yielded similar results, and therefore we present only the results of the seeding approach, which was slightly better.
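The adaptation step above can be sketched as follows. This is an illustrative NumPy implementation of the seeding approach, not the original code; the name `seeded_kmeans` is ours.

```python
import numpy as np

def seeded_kmeans(z_labeled, y_labeled, z_unlabeled, n_way, n_iter=10):
    """Seeded k-means (Basu et al., 2002) in the embedding space.

    Cluster means are initialized with the class prototypes of the
    labeled examples; standard Lloyd iterations then run on the union
    of labeled and unlabeled embeddings.
    """
    means = np.stack([z_labeled[y_labeled == k].mean(axis=0)
                      for k in range(n_way)])
    z_all = np.concatenate([z_labeled, z_unlabeled])
    for _ in range(n_iter):
        d2 = ((z_all[:, None, :] - means[None, :, :]) ** 2).sum(-1)
        assign = d2.argmin(axis=1)
        new_means = np.stack([
            z_all[assign == k].mean(axis=0) if (assign == k).any()
            else means[k]                    # keep empty clusters fixed
            for k in range(n_way)])
        if np.allclose(new_means, means):
            break                            # typically a few iterations
        means = new_means
    return means, assign
```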
The proposed algorithm is related to the one concurrently developed by Ren et al. (2018), who do semi-supervised few-shot classification using constrained k-means clustering. The main difference is that they perform clustering also at training time and therefore use soft cluster assignments to keep the computational graph differentiable. They do only one iteration of k-means, as doing more iterations does not improve their performance. We obtained similar results using soft cluster assignments and found experimentally that k-means with hard clustering behaves more robustly (see results in Table 5). Furthermore, Ren et al. (2018) consider the scenario in which the unlabeled support set may contain samples from irrelevant classes.
3.2 Unsupervised Adaptation
In many practical applications, the classification tasks to which a classifier needs to adapt share the same output space. As an example, consider a material recognition system deployed in multiple factories. The same categories of materials need to be recognized at different sites; however, the factories may have slightly different lighting conditions or other factors affecting the recognition process, motivating the need to adapt each recognition system. In these scenarios, since the classes do not change across tasks, average class prototypes can be computed at the end of training by averaging the prototypes of the corresponding class over all the training tasks. At test time, we cluster the samples using k-means and then assign each cluster to the class of the closest average prototype.
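A minimal sketch of this unsupervised adaptation step, assuming the per-task prototypes were stored during training (illustrative code; the function names are ours):

```python
import numpy as np

def average_prototypes(task_prototypes):
    """Average the per-task class prototypes over all training tasks.

    `task_prototypes` has shape (n_tasks, n_classes, dim); the class
    order is assumed fixed across tasks (shared output space).
    """
    return np.asarray(task_prototypes).mean(axis=0)

def label_clusters(cluster_means, avg_protos):
    """Assign each test-time k-means cluster to the class of the
    closest average prototype."""
    d2 = ((cluster_means[:, None, :] - avg_protos[None, :, :]) ** 2).sum(-1)
    return d2.argmin(axis=1)   # cluster index -> class index
```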
3.3 Experiments with Synthetic Data
To test the proposed algorithm, we created a synthetic dataset which is fast to experiment with. The dataset consists of two-dimensional binary classification tasks in which the optimal decision boundary is a sine wave (see Fig. 1). The amplitude and the phase of the optimal decision boundary vary across tasks. The first dimension of each data sample is drawn uniformly, and the second dimension is computed as the value of the sine-wave boundary plus a noise term ε drawn from a Laplace distribution whose mean depends on the class and whose scale parameter is 0.5. We sampled 100 tasks for training and 1000 tasks for testing.
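A generator for such tasks might look as follows. This is an illustrative sketch: the amplitude, phase, input ranges and class means below are placeholder assumptions (the exact values did not survive in the text); only the Laplace noise with scale 0.5 and a class-dependent mean is taken from the description above.

```python
import numpy as np

def sample_task(n_samples, rng):
    """One synthetic two-class task with a sine-wave decision boundary."""
    amp = rng.uniform(1.0, 2.0)            # assumed amplitude range
    phase = rng.uniform(0.0, np.pi)        # assumed phase range
    x1 = rng.uniform(-5.0, 5.0, size=n_samples)   # assumed input range
    y = rng.integers(0, 2, size=n_samples)        # class labels
    # Class-dependent Laplace noise around the boundary, scale 0.5.
    offset = np.where(y == 0, -1.0, 1.0)          # assumed class means
    x2 = amp * np.sin(x1 + phase) + rng.laplace(offset, 0.5)
    return np.stack([x1, x2], axis=1), y
```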
In this experiment, we use a fully connected network with two hidden layers of size 40 and ReLU nonlinearities as the embedding network. Examples of the decision boundaries produced by PN on test tasks are shown in Fig. 1. Note that even for a small number of training examples the PN decision boundaries resemble the sine wave, which shows that knowledge is transferred between tasks. The classification errors of semi-supervised adaptation on test tasks as a function of the number of unlabeled samples are shown in Table 2. One can see from the results that the PN accuracy improves with the use of unlabeled data. However, the improvement plateaus, and the results do not improve much after adding more than 100 data points.
We also experimented with the proposed unsupervised adaptation method on the synthetic dataset. The classification errors as a function of the number of unlabeled samples are shown in Table 2.
3.4 Experiments with miniImagenet
We tested the proposed method in the semi-supervised scenario on the miniImagenet recognition task proposed by Vinyals et al. (2016). The dataset consists of downsampled 84×84 images from 64 training classes, 16 validation classes, and 20 test classes from ImageNet. We use the same split as Ravi and Larochelle (2016) and follow the experimental setup of N-way K-shot classification, similarly to Vinyals et al. (2016): every task at test time contains N classes with K labeled examples from each class. For the semi-supervised case, we also assume the existence of M unlabeled samples per class at test time. Following previous works, we use 15 test samples per class for each task during training and for evaluation. We evaluate the model over 2400 tasks which involve the classes reserved for testing.
One challenge with the miniImagenet dataset is that it requires rather complex features while containing relatively little data, so preventing overfitting becomes an important issue. Most previous works (Vinyals et al., 2016; Ravi and Larochelle, 2016; Finn et al., 2017b; Snell et al., 2017) used conventional convolutional networks with a small number of layers to prevent overfitting. For comparability, we use the same architecture, which consists of four blocks. Each block consists of a 3×3 convolutional layer with 64 channels followed by batch normalization, a ReLU non-linearity and a 2×2 max-pooling layer, which results in an embedding space of dimensionality 1600. In Prototypical Networks (Snell et al., 2017), the model was fine-tuned by having more classes at training time than at test time (e.g., 30-way training and 5-way testing) and a learning rate decay schedule. We use a constant learning rate and 5-way training for simplicity. This simple network may not be expressive enough to capture the relevant complex features, so we also consider a Wide Residual Network (Zagoruyko and Komodakis, 2016) as the embedding network. We use a network of depth 16 and a widening factor of 6, with a pooling layer with a stride of 4 at the end to obtain embeddings of dimensionality 384. The network is regularized with dropout. We use the Adam optimizer (Kingma and Ba, 2014) with a learning rate of 0.01 for training the ResNet and 0.001 for training the four-block architecture, and we perform early stopping to prevent overfitting. (Residual Networks are typically trained using stochastic gradient descent with momentum, and we would expect better results from doing the same and fine-tuning the hyperparameters.) We train the ResNet model with more classes in each task than at test time for both 1-shot and 5-shot classification, similarly to the original PN paper. For the ResNet model, we tried varying the number of classes during training and observed that it had a much smaller impact on the results than reported in Snell et al. (2017), which suggests that tuning this regularization parameter is less important for this architecture.
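As a sanity check of the stated embedding dimensionality, assuming each of the four blocks ends with a resolution-halving max-pooling layer (floor division), an 84×84 input is reduced to a 5×5 map with 64 channels:

```python
def conv_embed_dim(img_size=84, n_blocks=4, channels=64):
    """Embedding size of the four-block conv net: each block halves
    the spatial resolution (84 -> 42 -> 21 -> 10 -> 5)."""
    s = img_size
    for _ in range(n_blocks):
        s //= 2
    return s * s * channels

# 5 * 5 * 64 = 1600, as stated in the text.
```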
The results presented in Table 3 show that our approach scales to both feature extractor architectures. One can observe that using the Wide ResNets to learn the embedding space yields noticeable improvements in the classification accuracy compared to the baseline methods. The improvements are more significant for the 20-way classification. Table 3 presents the results of the PN tested in the semi-supervised scenario: The prototypes are re-estimated at test time using both labeled samples and the inputs from the query set.
In Table 4, we show how the number of unlabeled examples at test time affects the classification accuracy of the trained PN. The results indicate that more unlabeled samples yield better performance; however, the improvement plateaus very quickly as the number of unlabeled samples grows. This agrees with the results obtained on the synthetic data reported in Section 3.3. Notably, the classification performance of the four-block architecture scales well with an increasing number of unlabeled samples, closely matching the performance of the ResNet in the case of 120 unlabeled samples per class. The evolution of the classification accuracy with an increasing number of k-means iterations is shown in Table 5.
[Table 3: Average classification accuracy (with 95% confidence intervals) on miniImagenet, for 5-way and 20-way testing. Rows: Meta-LSTM (Ravi and Larochelle, 2016), Matching nets (Vinyals et al., 2016), MAML (Finn et al., 2017b), ARC (Shyam et al., 2017), Meta Networks (Munkhdalai and Yu, 2017), PN (Snell et al., 2017), Meta-SGD (Li et al., 2017), PN re-estimate (ours), ResNet PN (ours), ResNet PN re-estimate (ours). The comparison numbers for the 20-way testing are from Li et al. (2017).]
[Table 5: 5-way testing accuracy of hard k-means vs. soft k-means as a function of the number of k-means iterations.]
4 Active Few-Shot Adaptation
There are two types of errors which the semi-supervised adaptation algorithm proposed in Section 3.1 can accumulate: 1) errors due to incorrect clustering of the data, and 2) errors due to incorrect labeling of the clusters. The second type of error can occur when the few labeled examples are outliers which end up closer to the prototype of another class in the embedding space. In this paper, we advocate that the most practical way to correct the second type of errors is through user feedback, since in many applications of semi-supervised few-shot adaptation, interaction with the user is possible. This idea is inspired by the work of Cohn et al. (2003), who introduced a clustering approach that allows a user to iteratively provide feedback to a clustering algorithm.
Consider the previously introduced example of few-shot learning in photo management applications. Although it is possible to ask the user to label a few photographs and use those labels to classify the rest of the pictures, it is extremely difficult and tiresome for the user to scroll through all the photos and decide which samples should be labeled. Instead, using the observation that “It is easier to criticize than to create” (Cohn et al., 2003), one can initially cluster the photos and then request the user to label certain photos (or provide other types of feedback) so that the data are properly clustered and labeled. The user can provide feedback in various forms and can thereby effectively introduce various constraints that further guide the clustering process. For example, a user can assign a whole cluster to a particular class, assign a sample to a particular cluster, mark that a particular sample does not belong to its assigned cluster, or split and combine clusters. These constraints can easily be incorporated into basic clustering algorithms such as k-means. For example, Wagstaff et al. (2001) introduced constraints between samples in the dataset, such as must-link (two samples have to be in the same cluster) and cannot-link (two samples have to be in different clusters), and their clustering algorithm finds a solution that satisfies all the constraints.
Even outside the context of few-shot learning, this active learning approach can be used to adapt a pre-trained classifier. Assume that we have a classifier that clusters the classes of a particular classification task such as ImageNet. Then, during test time it is possible to interactively split clusters to make coarse-grained classifications or to assign multiple clusters to a super-cluster (to make hierarchical predictions).
In this paper, we assume that the user can provide feedback only in the form of labeling a particular sample or labeling a whole cluster. We propose to use PN as a feature extractor, cluster the samples in the embedding space using k-means and then label the clusters by requesting one labeled example for each cluster from the user. For each cluster c, we choose the sample z* to be labeled by the user by maximizing an acquisition function a(z):

    z* = argmax_{z ∈ Z_c} a(z),

where Z_c is the set of embedded inputs belonging to cluster c. We explore a few acquisition functions:
Random: Sample a data point uniformly at random from each cluster. This is a baseline approach.
Nearest: Select the data point which is closest to the cluster center:

    a(z) = −‖z − μ_c‖,

where μ_c is the mean (cluster center) of cluster c.
Entropy: Select the sample with the least entropy:

    a(z) = Σ_c p(c | z) log p(c | z),

where p(c | z) is the probability that the embedded input z belongs to cluster c. Thus, we select the sample with the least uncertainty about which cluster it belongs to.
Margin: Select a sample with the largest margin between the most likely and the second most likely cluster:

    a(z) = p(c₁ | z) − p(c₂ | z),

where c₁ and c₂ are the most likely and the second most likely clusters of the embedded input z, respectively. This quantity was proposed as a measure of uncertainty by Scheffer et al. (2001).
We also simulate the case when the user can label a whole cluster, as this is certainly possible in some applications. This approach directly measures the clustering accuracy, and we call it “oracle”.
Oracle: We label each cluster based on the distance of the cluster mean to the prototypes computed from the true labels of all the samples.
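The acquisition functions above can be sketched as follows. This is illustrative NumPy code with our own function names; `probs` denotes per-cluster membership probabilities for each sample (e.g., a softmax over negative distances to the cluster means).

```python
import numpy as np

def nearest(z_cluster, mean):
    # Negated distance to the cluster center: maximized by the closest point.
    return -np.linalg.norm(z_cluster - mean, axis=1)

def entropy_score(probs):
    # Negative entropy: maximized by the most confident (least uncertain) sample.
    eps = 1e-12
    return (probs * np.log(probs + eps)).sum(axis=1)

def margin(probs):
    # Gap between the two most likely clusters (Scheffer et al., 2001).
    top2 = np.sort(probs, axis=1)[:, -2:]
    return top2[:, 1] - top2[:, 0]
```

In each cluster, the sample maximizing the chosen score is sent to the user for labeling.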
4.1 Experiments with miniImagenet
In the experiments with miniImagenet we use a PN trained in the episodic mode as the feature extractor. We simulate active learning on test tasks by first doing k-means clustering in the PN embedding space and then requesting one labeled example for each cluster using the acquisition functions described earlier. Note that multiple clusters can be labeled with the same class if the requested labels guide it that way; this is the largest source of error (see also Fig. 2). Table 6 presents the classification performance of each strategy for test tasks with one labeled sample and a varying number of unlabeled samples. There, we also present the accuracy of the oracle clustering. Overall, the margin approach worked best in our experiments. The 1-shot classification accuracy with 120 unlabeled samples per class even surpassed the 5-shot accuracy of some well-recognized previous methods. Similarly to the semi-supervised scenario, the four-block architecture scales well with an increasing number of unlabeled samples, closely matching the performance of the ResNet in the case of 120 unlabeled samples per class and even outperforming it when using the margin strategy.
5 Conclusion

In this paper, we extended Prototypical Networks to adapt to new classification tasks in the semi-supervised few-shot learning scenario, when a few labeled examples are accompanied by many unlabeled examples from the same classes. We proposed to use the clustering approach to semi-supervised classification, in which the clustering process is guided by the labeled examples. This is different from recent deep learning papers (Rasmus et al., 2015; Miyato et al., 2015; Laine and Aila, 2016; Tarvainen and Valpola, 2017), which constrain the classifier using unlabeled data. We also advocated that in many real-world applications it is possible to request the few labeled examples from the user, which can yield better performance.
The proposed solution for semi-supervised few-shot adaptation is based on doing k-means clustering in the embedding space found by Prototypical Networks. These two methods are a good fit because they make similar assumptions about the data distribution: in Prototypical Networks, the distribution of each class is represented by its mean, and the variances of the class distributions are assumed equal. The same assumptions are made by k-means.
The fundamental bottleneck of the proposed approach in improving the classification performance is the ability of the feature extractor to cluster unseen data. Although we used an embedding network trained using Prototypical Networks, the adaptation mechanisms proposed in this paper can be performed using other feature extractors as well. A feature extractor explicitly trained to cluster data can further improve the few-shot classification performance and this is an area of active research (Song et al., 2016, 2017; Law et al., 2017). Building feature extractors that allow better generalization is largely an unsolved problem and it requires further exploration (see, e.g., Sabour et al., 2017; Hinton et al., 2018).
Acknowledgments

We would like to thank our colleagues from The Curious AI Company for fruitful discussions.
References

- Basu et al. (2002) Basu, S., Banerjee, A., and Mooney, R. (2002). Semi-supervised clustering by seeding. In Proceedings of the 19th International Conference on Machine Learning (ICML-2002).
- Chapelle et al. (2006) Chapelle, O., Schölkopf, B., and Zien, A., editors (2006). Semi-Supervised Learning. MIT Press.
- Cohn et al. (2003) Cohn, D., Caruana, R., and McCallum, A. (2003). Semi-supervised clustering with user feedback. Constrained Clustering: Advances in Algorithms, Theory, and Applications, 4(1), 17–32.
- Finn et al. (2017a) Finn, C., Abbeel, P., and Levine, S. (2017a). Model-agnostic meta-learning for fast adaptation of deep networks. arXiv preprint arXiv:1703.03400.
- Finn et al. (2017b) Finn, C., Yu, T., Zhang, T., Abbeel, P., and Levine, S. (2017b). One-shot visual imitation learning via meta-learning. arXiv preprint arXiv:1709.04905.
- Hinton et al. (2018) Hinton, G. E., Sabour, S., and Frosst, N. (2018). Matrix capsules with EM routing. International Conference on Learning Representations.
- Kingma and Ba (2014) Kingma, D. and Ba, J. (2014). Adam: A method for stochastic optimization. arXiv preprint arXiv:1412.6980.
- Koch et al. (2015) Koch, G., Zemel, R., and Salakhutdinov, R. (2015). Siamese neural networks for one-shot image recognition. In ICML Deep Learning Workshop, volume 2.
- Laine and Aila (2016) Laine, S. and Aila, T. (2016). Temporal ensembling for semi-supervised learning. arXiv preprint arXiv:1610.02242.
- Lake et al. (2015) Lake, B. M., Salakhutdinov, R., and Tenenbaum, J. B. (2015). Human-level concept learning through probabilistic program induction. Science, 350(6266), 1332–1338.
- Law et al. (2017) Law, M. T., Urtasun, R., and Zemel, R. S. (2017). Deep spectral clustering learning. In International Conference on Machine Learning, pages 1985–1994.
- Li et al. (2017) Li, Z., Zhou, F., Chen, F., and Li, H. (2017). Meta-SGD: Learning to learn quickly for few shot learning. arXiv preprint arXiv:1707.09835.
- Lloyd (1982) Lloyd, S. (1982). Least squares quantization in PCM. IEEE Transactions on Information Theory, 28(2), 129–137.
- Miyato et al. (2015) Miyato, T., Maeda, S.-i., Koyama, M., Nakae, K., and Ishii, S. (2015). Distributional smoothing with virtual adversarial training. arXiv preprint arXiv:1507.00677.
- Munkhdalai and Yu (2017) Munkhdalai, T. and Yu, H. (2017). Meta networks. arXiv preprint arXiv:1703.00837.
- Rasmus et al. (2015) Rasmus, A., Berglund, M., Honkala, M., Valpola, H., and Raiko, T. (2015). Semi-supervised learning with Ladder networks. In Advances in Neural Information Processing Systems, pages 3546–3554.
- Ravi and Larochelle (2016) Ravi, S. and Larochelle, H. (2016). Optimization as a model for few-shot learning. In International Conference on Learning Representations.
- Ren et al. (2018) Ren, M., Triantafillou, E., Ravi, S., Snell, J., Swersky, K., Tenenbaum, J. B., Larochelle, H., and Zemel, R. S. (2018). Meta-learning for semi-supervised few-shot classification. In International Conference on Learning Representations.
- Sabour et al. (2017) Sabour, S., Frosst, N., and Hinton, G. E. (2017). Dynamic routing between capsules. In Advances in Neural Information Processing Systems, pages 3859–3869.
- Scheffer et al. (2001) Scheffer, T., Decomain, C., and Wrobel, S. (2001). Active hidden Markov models for information extraction. In International Symposium on Intelligent Data Analysis, pages 309–318. Springer.
- Shyam et al. (2017) Shyam, P., Gupta, S., and Dukkipati, A. (2017). Attentive recurrent comparators. arXiv preprint arXiv:1703.00767.
- Snell et al. (2017) Snell, J., Swersky, K., and Zemel, R. S. (2017). Prototypical networks for few-shot learning. arXiv preprint arXiv:1703.05175.
- Song et al. (2016) Song, H. O., Xiang, Y., Jegelka, S., and Savarese, S. (2016). Deep metric learning via lifted structured feature embedding. In Computer Vision and Pattern Recognition (CVPR), 2016 IEEE Conference on, pages 4004–4012. IEEE.
- Song et al. (2017) Song, H. O., Jegelka, S., Rathod, V., and Murphy, K. (2017). Deep metric learning via facility location. In Computer Vision and Pattern Recognition (CVPR).
- Tarvainen and Valpola (2017) Tarvainen, A. and Valpola, H. (2017). Weight-averaged consistency targets improve semi-supervised deep learning results. arXiv preprint arXiv:1703.01780.
- Vinyals et al. (2016) Vinyals, O., Blundell, C., Lillicrap, T., Wierstra, D., et al. (2016). Matching networks for one shot learning. In Advances in Neural Information Processing Systems, pages 3630–3638.
- Wagstaff et al. (2001) Wagstaff, K., Cardie, C., Rogers, S., Schrödl, S., et al. (2001). Constrained k-means clustering with background knowledge. In ICML, volume 1, pages 577–584.
- Zagoruyko and Komodakis (2016) Zagoruyko, S. and Komodakis, N. (2016). Wide residual networks. arXiv preprint arXiv:1605.07146.
Appendix A Sensitivity to the Number of Shots
Snell et al. (2017) showed that the classification performance of Prototypical Networks is sensitive to the number of classes per task during training and that it was necessary to match the number of samples per class (shots) between training and testing. We believe that using a larger number of classes (higher N-way) works as a regularizer. However, varying N effectively changes the batch size, and therefore the learning rate needs to be adapted to N. By tuning the learning rate for different N, we observed a smaller effect of this regularization on the adaptation accuracy.

The dependency on the number of shots K implies that the embedding learned by the network does not generalize well to a different number of samples at test time. This is inconvenient, especially if the number of labeled examples is unknown in advance and can grow, for example, as a result of interaction with the user. To address this problem, we propose to use a varying number of samples per class during training for better generalization to the number of shots at test time. The results in Table 7 illustrate that this strategy is effective and reduces the sensitivity to the training K.
[Table 7: Classification accuracy as a function of the training shot and the testing shot.]