Centroid Networks for Few-Shot Clustering and Unsupervised Few-Shot Classification

by Gabriel Huang et al.

Traditional clustering algorithms such as K-means rely heavily on the nature of the chosen metric or data representation. To get meaningful clusters, these representations need to be tailored to the downstream task (e.g. cluster photos by object category, cluster faces by identity). Therefore, we frame clustering as a meta-learning task, few-shot clustering, which allows us to specify how to cluster the data at the meta-training level, despite the clustering algorithm itself being unsupervised. We propose Centroid Networks, a simple and efficient few-shot clustering method based on learning representations which are tailored both to the task to solve and to its internal clustering module. We also introduce unsupervised few-shot classification, which is conceptually similar to few-shot clustering, but is strictly harder than supervised few-shot classification and therefore allows direct comparison with existing supervised few-shot classification methods. On Omniglot and miniImageNet, our method achieves accuracy competitive with popular supervised few-shot classification algorithms, despite using *no labels* from the support set. We also show performance competitive with state-of-the-art learning-to-cluster methods.





1 Introduction

Clustering is generally not a well-specified problem. Consider the classic task of clustering news articles (Finley & Joachims, 2005). Does the user want articles clustered by topic, author, or language? Consider another task of clustering pictures of cars. One could group cars by color, brand/make, vehicle type, or year, but which criterion is right? Hence, the definition of the clustering criterion can be arbitrary and should be informed by the specific application or use case. For example, a natural problem with a clear use case is face clustering. Photo gallery applications often allow users to browse photos by the identity of the individuals present in the photo collection. This is not a supervised classification problem, since each user will have photos of different individuals. It is instead a clustering problem, but one driven specifically by the criterion of person identity.

Therefore, we propose to frame clustering as a meta-learning problem, using the few-shot learning framework defined in Vinyals et al. (2016). The goal is to learn a clustering algorithm which takes unlabeled data as input and outputs a classifier which associates each data point with a cluster. As motivated above, we also assume that the clustering algorithm to be learned must group data according to some ground-truth notion of cluster. That is, while clustering is an unsupervised learning task, ultimately we will evaluate the quality of clusters with respect to some annotation. We therefore use this annotation to define a meta-training task, where episodes corresponding to clustering problems are generated according to the ground-truth annotation. The meta-learning task is then to train a clustering algorithm that minimizes a loss with respect to an ideal, ground-truth clustering of the data.

As an example, consider a dataset of face images, as required by the photo library use case described above. The annotation that determines the clusters would be the identity of the person. Accordingly, meta-training would be based on generating episodes that each contain the images of a subset of the people in the dataset, along with their target ground-truth clustering based on identity. After meta-training, the generalization of the clustering algorithm would then be evaluated on episodes containing the images of different individuals, never seen by the clustering algorithm.

In this work, we consider the particularly difficult setting where clusters contain a small number of data points. In the photo gallery example, this few-shot setting would correspond to a user who has taken only a few pictures of each of their acquaintances. We thus refer to this problem as few-shot clustering. We also consider the setting where, once cluster centroids have been extracted, new examples are presented and must be classified into these clusters. This problem is akin to few-shot classification, except that no labels are used to estimate the classes (here, clusters). Hence we refer to this problem as unsupervised few-shot classification.

Our contributions are two-fold:

  • We propose Centroid Networks, a new method for few-shot clustering and unsupervised few-shot classification, which combines ideas from Prototypical Networks, K-means, and optimal transport. Centroid Networks are simple and only have one learnable component, the embedding function, which is trained using a supervised surrogate loss. At test time, Centroid Networks are also very fast to run: once embeddings are computed, the clustering process itself takes virtually no time.

  • We benchmark Centroid Networks on Omniglot and miniImageNet and show strong results compared to state-of-the-art methods from the learning-to-cluster and supervised few-shot classification literature. In particular, our method achieves surprisingly high accuracies relative to supervised few-shot classification methods, despite using no labels from the support set.

2 Related Work

Supervised clustering. Supervised clustering is defined in Finley & Joachims (2005) as “learning how to cluster future sets of items […] given sets of items and complete clusterings over these sets”. They use a structured SVM to learn a similarity metric between pairs of items, then run a fixed clustering algorithm which optimizes the sum of similarities of pairs in the same cluster. In follow-up work (Finley & Joachims, 2008), they use K-means as the clustering algorithm. A main difference with this line of work is that we learn a nonlinear embedding function, whereas their framework assumes linear embeddings. The work of Awasthi & Zadeh (2010) is also called supervised clustering, although it solves a very different problem: it proposes a clustering algorithm which repeatedly presents candidate clusterings to a “teacher” and actively requests feedback (supervision).

Learning to cluster. Recent deep learning literature has preferred the term “learning to cluster” to “supervised clustering”. Although the task is the same, the main difference is that the similarity metric is learned with deep networks. Because of this, these works are often classified as part of the metric learning literature.

Hsu et al. (2017, 2019) propose a Constrained Clustering Network (CCN) for learning to cluster, based on two distinct steps: learning a similarity metric to predict whether two examples are in the same class, and optimizing a neural network to predict cluster assignments which tend to agree with the similarity metric. CCNs obtained state-of-the-art results when compared against other supervised clustering algorithms, so we use CCN as a strong baseline. In our experiments, Centroid Networks improve over CCN on their benchmarks, while being simpler to train and computationally much cheaper.

Semi-supervised & constrained clustering. Semi-supervised clustering consists of clustering data with some supervision in the form of “this pair of points should be/not be in the same cluster”. Some methods take the pairwise supervision as hard constraints (Wagstaff et al., 2001), while others (including CCN) learn metrics which tend to satisfy those constraints (Bilenko et al., 2004). See the related work sections of Finley & Joachims (2005) and Hsu et al. (2017) for further references.

Supervised few-shot classification. For the unsupervised few-shot classification task, our method may be compared to the supervised few-shot classification literature (Vinyals et al., 2016; Ravi & Larochelle, 2017; Finn et al., 2017). In particular, we compare with Prototypical Networks (Snell et al., 2017), which were a source of inspiration for Centroid Networks. Our work is also related to follow-up work on Semi-Supervised Prototypical Networks (Ren et al., 2018), in which the support set contains both labeled and unlabeled examples. In this work, we go further by requiring no labels at all to infer centroids at evaluation time.

Sinkhorn K-means. The idea of formulating clustering as minimizing a Wasserstein distance between empirical distributions has been proposed several times in the past (Mi et al., 2018a). Canas & Rosasco (2012) make explicit some theoretical links between K-means and the Wasserstein-2 distance. The most similar work to Sinkhorn K-means is Regularized Wasserstein-Means (Mi et al., 2018b), but they use another method for solving optimal transport. Using Sinkhorn distances (a regularized version of the Wasserstein distance) specifically for clustering has even been suggested in Genevay et al. (2018). However, as we could not find an explicit description of Sinkhorn K-means anywhere in the literature, we coin the name and explicitly state the algorithm in Section 4. To our knowledge, we are the first to use Sinkhorn K-means in the context of learning to cluster and to scale it up to more complex datasets like miniImageNet. Note that our work should not be confused with Wasserstein K-means and similar variants, which consist of replacing the squared base-distance in K-means with a Wasserstein distance.

Meta-Learning and Unsupervised Learning. Finally, some recent work has explored combinations of unsupervised learning and meta-learning to address various other tasks. Metz et al. (2018) propose a method to meta-train an unsupervised representation learning model that produces useful features for some given task. That is, at evaluation time, their method produces features without requiring labels, much like Centroid Networks produce centroids without requiring labels. The difference with their method thus lies in the addressed task: we focus on clustering, while they consider the task of representation/feature learning. Hsu et al. (2018) and Khodadadeh et al. (2018) consider the opposite: meta-learning that requires no labels for meta-training but that delivers methods which require labels to be run at evaluation time. Specifically, they propose unsupervised approaches to generate episodes for supervised few-shot classification, while we use supervised data to learn an unsupervised clustering algorithm.

3 Background

We first review supervised few-shot classification, usually just called few-shot learning, before introducing the few-shot clustering and unsupervised few-shot classification tasks. Following standard practice for few-shot learning, we frame both as meta-learning problems.

3.1 Notation and Terminology

In each task, we denote $K$ the number of classes (ways), $M$/$M'$ the number of training/validation examples per class (shots/queries), and $N = KM$/$N' = KM'$ the total number of training/validation examples. (We follow the usual machine learning convention, where $K$ denotes the number of classes, which differs from the few-shot learning literature, where $K$ usually denotes the number of examples per class.)

Meta-learning. Meta-learning, which literally means learning to learn, can refer to many different settings, and the terminology is often confusing, so we clarify its meaning here. In our case, we call meta-learning the problem of learning to solve tasks. That is, given a distribution $\mathcal{D}$ over a set of tasks (technically called base-tasks), we want to learn a general method for solving tasks from $\mathcal{D}$. (Tasks are to meta-learning what data are to regular machine learning.) The performance on each individual task is computed using a base-evaluation metric. Typically, we have access to three disjoint sets of tasks, assumed to be sampled from $\mathcal{D}$. Similarly to regular machine learning, we meta-train the method to solve tasks on the meta-training set, tune its hyperparameters on tasks from the meta-validation set, and evaluate its performance on the meta-testing set. End-to-end methods are trained by directly minimizing the base-evaluation metric (or a subdifferentiable surrogate), i.e., the meta-training loss is the base-evaluation metric.

3.2 Supervised Few-shot Classification

Supervised few-shot classification consists of learning a classifier from a small number of examples. In practice, this is formulated as K-way M-shot learning, where $K$ and $M$ are small (standard benchmarks use 5 or 20 ways and 1 or 5 shots). Each base-task comes with two disjoint sets of data, which are typically images:

  • Labeled support set containing K classes and M training examples per class (base-training set).

  • Unlabeled query set also containing K classes but with M’ examples per class (base-validation set).

The base-task is to train a classifier on the support set and make predictions on the query set. For base-evaluation, the query set labels are revealed and the accuracy on the query set is computed. We call this the supervised accuracy. End-to-end methods are meta-trained by directly minimizing the base-evaluation metric (more precisely the cross-entropy on the query set, which is a surrogate).
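To make the episode structure concrete, here is a minimal NumPy sketch of how one K-way M-shot episode could be sampled from a labeled pool of examples. The function `sample_episode` and its arguments are hypothetical illustrations, not code from the paper, and the sketch assumes every class has at least M + M' examples.

```python
import numpy as np

def sample_episode(labels, n_way, n_shot, n_query, rng):
    """Sample one K-way M-shot episode (hypothetical helper).

    Returns support indices (K*M), episode-local support labels in {0..K-1},
    query indices (K*M'), and episode-local query labels.
    """
    labels = np.asarray(labels)
    # Pick K distinct classes for this episode.
    classes = rng.choice(np.unique(labels), size=n_way, replace=False)
    support_idx, query_idx = [], []
    for cls in classes:
        # Draw M + M' distinct examples of this class, then split them.
        members = rng.choice(np.flatnonzero(labels == cls),
                             size=n_shot + n_query, replace=False)
        support_idx.append(members[:n_shot])
        query_idx.append(members[n_shot:])
    support_y = np.repeat(np.arange(n_way), n_shot)
    query_y = np.repeat(np.arange(n_way), n_query)
    return np.concatenate(support_idx), support_y, np.concatenate(query_idx), query_y
```

Relabeling the sampled classes as 0..K-1 inside each episode reflects the fact that class identities are not shared across episodes; only the partition matters.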

3.3 Few-shot Clustering

Few-shot clustering consists of learning to cluster a small number of examples. In this work, we formulate it as K-way M-shot clustering, where $K$ and $M$ are small. Each base-task comes with a single set of data, which are also usually images:

  • Unlabeled support set containing K classes and M training examples per class (we don’t know which classes are present nor which class each image belongs to).

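Unlike the supervised case, there is no longer a support/query set distinction, but we will still call $S$ the support set, which will be convenient later. The base-task is to partition $S$ into $K$ disjoint clusters. We denote $\hat{y}_i$ the predicted cluster index for point $x_i$. For base-evaluation, the ground-truth labels $y_i$ are revealed, and the optimal one-to-one matching $\sigma$ between cluster indices and classes (a permutation of the $K$ cluster indices) which maximizes the accuracy is computed. Finally, the accuracy is computed after permuting the cluster indices:

$$\text{ClusteringAcc}(\hat{y}, y) = \max_{\sigma} \frac{1}{N} \sum_{i=1}^{N} \mathbb{1}\{y_i = \sigma(\hat{y}_i)\}$$

where the maximum is taken over all permutations $\sigma$ of the $K$ cluster indices.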

We call this the clustering accuracy following the learning to cluster literature (Hsu et al., 2017). This metric will be used to compare our method with the learning to cluster literature.
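The clustering accuracy can be sketched in a few lines of NumPy. This is an illustration, not the authors' code: the helper names are hypothetical, and the matching is found by brute force over all K! permutations, which is only practical for small K (for larger K, `scipy.optimize.linear_sum_assignment` solves the same matching problem in polynomial time).

```python
import itertools
import numpy as np

def best_matching(pred, true, n_clusters):
    """Permutation sigma (sigma[cluster] = class) maximizing #{i : true_i == sigma[pred_i]}.

    Brute force over all K! permutations: only suitable for small K.
    """
    pred, true = np.asarray(pred), np.asarray(true)
    # counts[j, k] = number of points in predicted cluster j with true class k.
    counts = np.zeros((n_clusters, n_clusters), dtype=int)
    np.add.at(counts, (pred, true), 1)
    best = max(itertools.permutations(range(n_clusters)),
               key=lambda perm: counts[np.arange(n_clusters), list(perm)].sum())
    return np.array(best)

def clustering_accuracy(pred, true, n_clusters):
    """Accuracy after relabeling clusters with the optimal one-to-one matching."""
    sigma = best_matching(pred, true, n_clusters)
    return float(np.mean(sigma[np.asarray(pred)] == np.asarray(true)))
```

For instance, predictions `[2, 2, 0, 0, 1, 1]` against labels `[0, 0, 1, 1, 2, 2]` score 1.0, since they are a pure relabeling of the ground truth.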

3.4 Unsupervised Few-shot Classification

We introduce a new task, unsupervised few-shot classification, in order to compare our method with the supervised few-shot classification literature. Indeed, it is not fair to compare the clustering accuracy directly with the supervised accuracy. In few-shot clustering, we are given the whole set $S$ (in this paragraph only, think of $S$ as the query set, because it is the one we evaluate on) and we can predict clusters jointly, whereas in supervised few-shot classification, the predictions on the query set should be made independently and should not depend on the size of the set. In particular, 1-shot clustering is trivial because each point is always in a different cluster, while 1-shot classification is an actual task.

Therefore we introduce unsupervised few-shot classification, which is a strictly harder task than supervised few-shot classification, but has a very similar structure. Each base-task comes with two disjoint sets of data:

  • Unlabeled support set containing K classes and M training examples per class (base-training set).

  • Unlabeled query set containing K classes and M’ examples per class (base-validation set).

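The base-task is to cluster the support set $S$ into $K$ clusters, then to map each point in the query set $Q$ to one of the predicted clusters. Denote $\hat{y}$ and $\hat{y}'$ the predicted cluster indices on $S$ and $Q$. For base-evaluation, the ground-truth labels $y$ of the support set are revealed, and the optimal one-to-one matching $\sigma^\star$ which maximizes the accuracy on the support set is computed. Finally, we permute the predicted indices of the query set, reveal the ground-truth query set labels $y'$, and compute the accuracy:

$$\sigma^\star = \arg\max_{\sigma} \sum_{i=1}^{N} \mathbb{1}\{y_i = \sigma(\hat{y}_i)\}, \qquad \text{UnsupervisedAcc} = \frac{1}{N'} \sum_{i=1}^{N'} \mathbb{1}\{y'_i = \sigma^\star(\hat{y}'_i)\}$$

where the $\arg\max$ is taken over all permutations $\sigma$ of the $K$ cluster indices.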

We call this the unsupervised accuracy. We emphasize that this task is strictly harder than supervised few-shot classification, because the support-set labels are recovered correctly (through the matching) only if the support set is clustered correctly. The only exception is the 1-shot case, where unsupervised few-shot classification reduces to supervised few-shot classification, since a support set with one example per class can always be clustered perfectly. Therefore, unsupervised accuracy can be used to compare our method with the supervised few-shot classification literature.
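The key point of the unsupervised accuracy is that the matching is chosen on the support set only, before query labels are looked at. A minimal NumPy sketch, with a hypothetical function name and the same brute-force matching over K! permutations (small K only):

```python
import itertools
import numpy as np

def unsupervised_accuracy(support_pred, support_true, query_pred, query_true, n_clusters):
    """Match clusters to classes on the support set, then score the query set."""
    support_pred = np.asarray(support_pred)
    support_true = np.asarray(support_true)
    counts = np.zeros((n_clusters, n_clusters), dtype=int)
    np.add.at(counts, (support_pred, support_true), 1)
    # sigma[j] = class assigned to cluster j, chosen using support labels only.
    sigma = np.array(max(itertools.permutations(range(n_clusters)),
                         key=lambda perm: counts[np.arange(n_clusters), list(perm)].sum()))
    # Query labels are used only now, after sigma has been fixed.
    return float(np.mean(sigma[np.asarray(query_pred)] == np.asarray(query_true)))
```

If the support set is mis-clustered, the wrong permutation is carried over to the query set, which is why this metric can only be lower than or equal to what the same predictions would score with support labels given.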

3.5 Sinkhorn Distances

The Wasserstein-2 distance is a distance between two probability masses $p$ and $q$. Given a base distance $d(x, y)$, we define the cost of transporting one unit of mass from $x$ to $y$ as $d(x, y)^2$. The squared Wasserstein-2 distance is defined as the cheapest cost for transporting all mass from $p$ to $q$. When the transportation plan is regularized to have large entropy, we obtain Sinkhorn distances, which can be computed very efficiently for discrete distributions (Cuturi, 2013; Cuturi & Doucet, 2014), since entropy regularization makes the problem strongly convex. Sinkhorn distances are the basis of the Sinkhorn K-means algorithm, which is the main component of Centroid Networks.
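For reference, the discrete definitions can be written as follows, for $p = \sum_i a_i \delta_{x_i}$ and $q = \sum_j b_j \delta_{y_j}$. This is the common entropy-regularized form of the objective; Cuturi (2013) also discusses closely related variants, so the exact normalization here should be read as an assumption rather than the paper's definition.

```latex
% Squared Wasserstein-2 distance, with transport polytope U(a, b)
W_2^2(p, q) = \min_{P \in U(a, b)} \sum_{i,j} P_{ij}\, d(x_i, y_j)^2,
\qquad U(a, b) = \{ P \in \mathbb{R}_{+}^{N \times K} : P \mathbf{1}_K = a,\ P^\top \mathbf{1}_N = b \}

% Entropy-regularized (Sinkhorn) objective, with H(P) = -\sum_{i,j} P_{ij} \log P_{ij}
\min_{P \in U(a, b)} \sum_{i,j} P_{ij}\, d(x_i, y_j)^2 - \gamma\, H(P),
\qquad P^\star = \mathrm{diag}(u)\, L\, \mathrm{diag}(v), \quad L_{ij} = e^{-d(x_i, y_j)^2 / \gamma}
```

The scaling form of $P^\star$ is why the Sinkhorn algorithm only needs to alternately rescale the rows and the columns of the kernel matrix $L$ until both marginals match.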

3.6 Prototypical Networks

Prototypical Networks (Snell et al., 2017) are one of the simplest and most accurate supervised few-shot classification methods. Our method, Centroid Networks, is heavily inspired by them. The only learnable component of Protonets is the embedding function $h_\theta$, which maps images $x$ to an embedding (feature) space. Given a supervised task to solve, Protonets compute the average $\mu_k$ (the prototype) of each class $k$ on the support set. Each point $x$ from the query set is then classified according to the softmax of its negative squared distances to the prototypes:

$$p(y = k \mid x) = \frac{\exp(-\|h_\theta(x) - \mu_k\|^2)}{\sum_{l=1}^{K} \exp(-\|h_\theta(x) - \mu_l\|^2)}$$

Protonets are trained end-to-end by minimizing the log-loss on the query set.
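The Protonet classification rule can be sketched in NumPy as follows, assuming the embeddings $h_\theta(x)$ have already been computed (the network itself is omitted). The helper names are hypothetical illustrations.

```python
import numpy as np

def prototypes(support_z, support_y, n_way):
    """Class means mu_k of the embedded support set, shape (K, d)."""
    return np.stack([support_z[support_y == k].mean(axis=0) for k in range(n_way)])

def protonet_probs(query_z, protos):
    """Softmax over negative squared Euclidean distances to the prototypes."""
    sq_dist = ((query_z[:, None, :] - protos[None, :, :]) ** 2).sum(axis=-1)  # (N', K)
    logits = -sq_dist
    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    probs = np.exp(logits)
    return probs / probs.sum(axis=1, keepdims=True)
```

Each query row is computed independently of the others, which is what allows Protonets (and, later, the softmax conditional of Centroid Networks) to make independent predictions.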

4 Sinkhorn K-Means

Sinkhorn K-Means is the main clustering component of our Centroid Networks. Just like K-Means, Sinkhorn K-Means is a clustering algorithm which takes as input a set of points and outputs a set of centroids which best fit the data. Points are then clustered based on their distances to the centroids.

Conceptually, Sinkhorn K-Means attempts to iteratively minimize the Sinkhorn distance between the empirical distributions respectively defined by the data, $\frac{1}{N}\sum_{i=1}^{N} \delta_{x_i}$, and by the centroids, $\frac{1}{K}\sum_{j=1}^{K} \delta_{c_j}$. The idea of formulating clustering as the minimization of some Wasserstein distance between empirical distributions is not new. For instance, Section 4.1 of Genevay et al. (2018) mentions the possibility of using Sinkhorn distances for clustering. More generally, some theoretical links between K-means and the Wasserstein-2 distance are also explored in Canas & Rosasco (2012). However, to the best of our knowledge, we are the first to use Sinkhorn K-means in the context of learning to cluster and to scale it up to more complex datasets like miniImageNet.

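We present our version of the Sinkhorn K-Means optimization problem and compare it with regular K-Means. Both can be formulated as a joint minimization in the centroids $c_j \in \mathbb{R}^d$ (real vectors) and the assignments $p_{i,j} \ge 0$ (scalars), which specify how much of each point $x_i$ is assigned to centroid $c_j$:

  • K-Means. Note that compared to the usual convention, we have normalized assignments so that they sum up to 1.

    $$\min_{c,\, p} \sum_{i=1}^{N} \sum_{j=1}^{K} p_{i,j} \|x_i - c_j\|^2 \quad \text{s.t.} \quad p_{i,j} \in \{0, \tfrac{1}{N}\}, \quad \sum_{j=1}^{K} p_{i,j} = \tfrac{1}{N}$$

  • Sinkhorn K-Means.

    $$\min_{c,\, p} \sum_{i=1}^{N} \sum_{j=1}^{K} p_{i,j} \|x_i - c_j\|^2 - \gamma H(p) \quad \text{s.t.} \quad p_{i,j} \ge 0, \quad \sum_{j=1}^{K} p_{i,j} = \tfrac{1}{N}, \quad \sum_{i=1}^{N} p_{i,j} = \tfrac{1}{K}$$

    where $H(p) = -\sum_{i,j} p_{i,j} \log p_{i,j}$ is the entropy of the assignments, and $\gamma > 0$ is a parameter tuning the entropy penalty term.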

Sinkhorn vs. Regular K-means. The first difference is that K-means only allows hard assignments $p_{i,j} \in \{0, \frac{1}{N}\}$, that is, each point is assigned to exactly one cluster. In contrast, the Sinkhorn K-means formulation allows soft assignments $p_{i,j} \in [0, \frac{1}{N}]$, but with the additional constraint that the clusters have to be balanced, i.e., the same amount of points is soft-assigned to each cluster: $\sum_i p_{i,j} = \frac{1}{K}$. The second difference is the penalty term $-\gamma H(p)$, which encourages high-entropy solutions, i.e., points will tend to be assigned more uniformly over clusters, and clusters more uniformly over points. Adding entropy regularization allows us to compute $p$ very efficiently using the work of Cuturi (2013), as explained in the next paragraph. Beyond computational reasons, we will see in the next section that the entropy term is necessary in order to calibrate the meta-training and meta-validation phases of Centroid Networks. Note that removing the balancing constraint in the Sinkhorn K-Means objective would yield a regularized K-means objective with coordinate update steps identical to EM in a mixture of Gaussians (with $p$ updated using softmax conditionals).

Algorithms. Both K-means and Sinkhorn K-means can be solved iteratively by alternating coordinate descent on the assignments and centroids. Minimization in the assignments is an argmin in K-means, and a call to the Sinkhorn algorithm for Sinkhorn K-means (see Algorithm 1); in practice we use the log-sum-exp trick in the Sinkhorn algorithm to avoid numerical underflows. Minimization in the centroids amounts to setting them equal to the weighted average of the points assigned to them. The full Sinkhorn K-means procedure is described in Algorithm 2. We initialize centroids around zero and add a tiny bit of Gaussian noise to break symmetries (all details are in the code).

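  Input: data $x_1, \dots, x_N$, centroids $c_1, \dots, c_K$, regularization constant $\gamma > 0$.
  Output: optimal transport plan $p \in \mathbb{R}^{N \times K}$.
  Compute pairwise squared L2-distances, scale by $-1/\gamma$ and exponentiate: $L_{i,j} = \exp(-\|x_i - c_j\|^2 / \gamma)$
  Initialize dual variables $u \in \mathbb{R}^N$, $v \in \mathbb{R}^K$: $u \leftarrow \mathbf{1}_N$, $v \leftarrow \mathbf{1}_K$.
  Initialize row and column marginal dist. to uniform: $a \leftarrow \frac{1}{N}\mathbf{1}_N$, $b \leftarrow \frac{1}{K}\mathbf{1}_K$.
  while not converged do
      $u \leftarrow a \oslash (L v)$   // Enforce row marginals ($\oslash$ is elementwise division)
      $v \leftarrow b \oslash (L^\top u)$   // Enforce column marginals
  end while
  Return optimal transport plan $p = \mathrm{diag}(u)\, L\, \mathrm{diag}(v)$, i.e., $p_{i,j} = u_i L_{i,j} v_j$.
Algorithm 1 Sinkhorn$(x, c, \gamma)$ for the entropy-regularized Wasserstein-2 distance
  Input: data $x_1, \dots, x_N$, initial centroids $c_1, \dots, c_K$, regularization constant $\gamma > 0$.
  Output: final centroids $c$, optimal assignment $p$.
  while not converged do
      // Expectation: update $p$.
     Compute optimal transport between data and centroids: $p \leftarrow \text{Sinkhorn}(x, c, \gamma)$
      // Maximization: update $c$.
     Update centroids to minimize distance to assigned points: $c_j \leftarrow \sum_i p_{i,j} x_i \,/\, \sum_i p_{i,j} = K \sum_i p_{i,j} x_i$
  end while
  Return centroids $c$ and assignments $p$.
Algorithm 2 Sinkhorn K-means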
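Algorithms 1 and 2 can be sketched in NumPy as follows. This is a minimal illustration, not the authors' implementation: it runs Sinkhorn in the log domain (the log-sum-exp trick mentioned above, with `f` and `g` standing for $\log u$ and $\log v$), and it replaces the "while not converged" tests by fixed iteration counts. The function names, the iteration counts and the noise scale `1e-3` are assumptions.

```python
import numpy as np

def logsumexp(a, axis):
    """Numerically stable log(sum(exp(a))) along one axis."""
    m = a.max(axis=axis, keepdims=True)
    return (m + np.log(np.exp(a - m).sum(axis=axis, keepdims=True))).squeeze(axis)

def sinkhorn(x, c, gamma=1.0, n_iters=200):
    """Algorithm 1 in the log domain: f = log u, g = log v, log_kernel = log L."""
    n, k = x.shape[0], c.shape[0]
    log_kernel = -((x[:, None, :] - c[None, :, :]) ** 2).sum(axis=-1) / gamma
    log_a = np.full(n, -np.log(n))  # uniform row marginals 1/N
    log_b = np.full(k, -np.log(k))  # uniform column marginals 1/K
    f, g = np.zeros(n), np.zeros(k)
    for _ in range(n_iters):  # fixed count instead of a convergence test
        f = log_a - logsumexp(log_kernel + g[None, :], axis=1)  # enforce row marginals
        g = log_b - logsumexp(log_kernel + f[:, None], axis=0)  # enforce column marginals
    return np.exp(f[:, None] + log_kernel + g[None, :])  # p_ij = u_i L_ij v_j

def sinkhorn_kmeans(x, k, gamma=1.0, n_outer=50, noise=1e-3, rng=None):
    """Algorithm 2: alternate balanced soft assignments and weighted-average centroids."""
    rng = np.random.default_rng() if rng is None else rng
    # Centroids start around zero, plus tiny Gaussian noise to break symmetry.
    c = noise * rng.standard_normal((k, x.shape[1]))
    for _ in range(n_outer):
        p = sinkhorn(x, c, gamma)               # expectation: update p
        c = (p.T @ x) / p.sum(axis=0)[:, None]  # maximization: update c
    return c, p
```

On two well-separated groups of five 2-D points centered at (0, 0) and (10, 10), this sketch recovers the two groups and places the centroids at their means; every column of `p` sums to 1/K, reflecting the balancing constraint.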

5 Centroid Networks

In this section, we describe our method and explain how it can be applied to few-shot clustering and unsupervised few-shot classification. Both tasks require clustering a set of points (either the support set or the query set) as part of base-training, and only the base-validation step differs. Centroid Networks consist of two modules: a trainable embedding module and a fixed clustering module.

Embedding module. The embedding module is directly inspired by Prototypical Networks (Snell et al., 2017), and consists of a neural network $h_\theta$ which maps data (images) $x$ to features $z = h_\theta(x)$ in the embedding space. As in Protonets, the only trainable neural network of Centroid Networks is the embedding function, which makes implementation very straightforward.

Clustering module. The clustering module takes as input the embedded data and outputs a set of centroids (representatives of each cluster) as well as the (soft) assignment of each point to each centroid. We use the Sinkhorn K-means algorithm as our clustering module, because it is experimentally very stable (unlike standard K-means++), and conveniently incorporates the constraint that the predicted clusters should be balanced.

We start by assuming that the embedder is already trained, and describe how to solve few-shot clustering and unsupervised few-shot classification (Section 5.1). Then we explain how we train the embedding module using a surrogate loss (Section 5.2).

5.1 Prediction and Evaluation

Recall from Sections 3.3 and 3.4 that both few-shot clustering and unsupervised few-shot classification require clustering the support set first. Denote $\hat{y}_i$ the predicted cluster index of point $x_i$.

Clustering the support set. First, we embed the raw data: $z_i = h_\theta(x_i)$. Then, we cluster the embeddings using Sinkhorn K-means, which returns centroids $c_j$ and soft assignments $p_{i,j}$ between data and centroids. We propose two ways of predicting clusters:

  • Softmax conditional. Similarly to Prototypical Networks, the conditional probability is given by a softmax on the distance between the point and the centroids:

    $$p(\hat{y}_i = j \mid x_i) = \frac{\exp(-\|z_i - c_j\|^2 / T)}{\sum_{k=1}^{K} \exp(-\|z_i - c_k\|^2 / T)}$$

    We add an extra temperature parameter $T > 0$. Larger temperatures yield more uniform assignments.

  • Sinkhorn conditional. We set the conditional probabilities equal to the optimal transport plan between the embedded points and the centroids, rescaled so that each row sums to 1 (each row of $p$ sums to $\frac{1}{N}$):

    $$p(\hat{y}_i = j \mid x_i) = N\, p_{i,j}$$

    Although there is no temperature parameter to tune, the Sinkhorn algorithm has a regularization parameter $\gamma$, which has a similar effect as the temperature, since changing either one is equivalent to rescaling the distance matrix $\|z_i - c_j\|^2$.

Using the Sinkhorn conditional is guaranteed to return balanced clusters, but using the Softmax conditional provides no such guarantees. On the other hand, once the centroids are known, the softmax conditional can yield independent predictions, while the Sinkhorn conditional needs access to the full set and makes joint predictions. This distinction will be crucial in the evaluation of the two settings (few-shot clustering vs. unsupervised few-shot classification).
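The difference between independent and joint predictions can be illustrated with a small NumPy sketch, assuming embeddings `z` and centroids `c` are given. The function names are hypothetical, and the Sinkhorn part is a compact log-domain version with a fixed iteration count.

```python
import numpy as np

def _logsumexp(a, axis):
    m = a.max(axis=axis, keepdims=True)
    return (m + np.log(np.exp(a - m).sum(axis=axis, keepdims=True))).squeeze(axis)

def _sq_dists(z, c):
    return ((z[:, None, :] - c[None, :, :]) ** 2).sum(axis=-1)  # shape (N, K)

def softmax_conditional(z, c, temperature=1.0):
    """Each row depends only on its own point: independent predictions."""
    logits = -_sq_dists(z, c) / temperature
    return np.exp(logits - _logsumexp(logits, axis=1)[:, None])

def sinkhorn_conditional(z, c, gamma=1.0, n_iters=200):
    """N * p_ij for the balanced transport plan over the whole set: joint predictions."""
    n, k = len(z), len(c)
    log_kernel = -_sq_dists(z, c) / gamma
    f, g = np.zeros(n), np.zeros(k)
    for _ in range(n_iters):
        f = -np.log(n) - _logsumexp(log_kernel + g[None, :], axis=1)  # rows sum to 1/N
        g = -np.log(k) - _logsumexp(log_kernel + f[:, None], axis=0)  # columns sum to 1/K
    return n * np.exp(f[:, None] + log_kernel + g[None, :])
```

On a toy 1-D example with points 0, 1, 2, 3 and centroids 0 and 10, the softmax conditional puts all four points in the first cluster, whereas the Sinkhorn conditional (here with $\gamma = 10$, a value chosen only for the toy example) splits them 2/2, because the balancing constraint couples the predictions of all points.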

Evaluation. Once the support set is clustered, we follow the evaluation procedures described in Sections 3.3 and 3.4. For both tasks, we find the optimal one-to-one matching $\sigma$ between ground-truth classes and predicted clusters. For few-shot clustering, we simply return the clustering accuracy. For unsupervised few-shot classification, we make independent predictions on the query set using the Softmax conditional (the Sinkhorn conditional cannot be used here, because we are not allowed to make joint predictions on the query set). Finally, we permute the cluster indices using $\sigma$ and return the unsupervised accuracy.

5.2 Training with a Supervised Surrogate Loss

The most intuitive way to train Centroid Networks would be to train them end-to-end, by backpropagating through Sinkhorn K-means, which contains two nested loops. Although this is technically possible after defining smoother versions of the clustering/unsupervised accuracies (by replacing the 0-1 loss with a cross-entropy), we did not have much success with this approach.

Instead, we opt for the much simpler approach of training with a supervised surrogate loss. Since we have access to the ground-truth classes during meta-training, we can simply replace the centroids $c_j$ with the average $\mu_j$ of each class. Then, we classify the query set points using either Softmax or Sinkhorn conditionals. Finally, we compute the log-loss on the query set and minimize it using gradient descent. In Softmax mode, surrogate loss minimization reduces to the standard training of Prototypical Networks. In fact, we named our method Centroid Networks because it can be seen as replacing the prototypes (class averages) by the centroids (weighted cluster averages) during prediction. However, training with the Sinkhorn conditional is different, and we show in Section 6 that it is always beneficial for the two tasks we solve. The supervised surrogate loss is very simple, as it removes both the need to find the optimal cluster-class permutation and the need to backpropagate through Sinkhorn K-means.
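Here is a minimal NumPy sketch of the forward computation of the surrogate loss, on embeddings assumed to be precomputed by $h_\theta$. The function `surrogate_loss` and its arguments are hypothetical. It computes only the loss value: actual training would need the same computation in an automatic-differentiation framework so that gradients reach the embedding network, and the fixed Sinkhorn iteration count is a simplification.

```python
import numpy as np

def _logsumexp(a, axis):
    m = a.max(axis=axis, keepdims=True)
    return (m + np.log(np.exp(a - m).sum(axis=axis, keepdims=True))).squeeze(axis)

def surrogate_loss(support_z, support_y, query_z, query_y, n_way,
                   mode="sinkhorn", gamma=1.0, temperature=1.0, n_iters=200):
    """Forward pass of the supervised surrogate loss (no gradients)."""
    # Centroids are replaced by the ground-truth class means of the support set.
    mu = np.stack([support_z[support_y == k].mean(axis=0) for k in range(n_way)])
    sq = ((query_z[:, None, :] - mu[None, :, :]) ** 2).sum(axis=-1)
    if mode == "softmax":  # reduces to the Prototypical Networks loss
        logits = -sq / temperature
        log_cond = logits - _logsumexp(logits, axis=1)[:, None]
    else:  # Sinkhorn conditional: log(N' * p_ij) for the balanced query plan
        n = len(query_z)
        log_kernel = -sq / gamma
        f, g = np.zeros(n), np.zeros(n_way)
        for _ in range(n_iters):
            f = -np.log(n) - _logsumexp(log_kernel + g[None, :], axis=1)
            g = -np.log(n_way) - _logsumexp(log_kernel + f[:, None], axis=0)
        log_cond = np.log(n) + f[:, None] + log_kernel + g[None, :]
    # Log-loss (cross-entropy) of the true query labels.
    return float(-log_cond[np.arange(len(query_y)), query_y].mean())
```

Note that no cluster-class matching appears anywhere: because the class means are indexed by the true labels, the loss is defined directly against the ground truth.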

Center Loss. In addition to the supervised surrogate loss, we use a center loss penalty (Wen et al., 2016). Center losses have been used in metric-learning methods to penalize the variance of each class in embedding space; see for instance Wen et al. (2016), where it is used in addition to the standard log-loss for learning discriminative face embeddings. Using a center loss makes sense because there is no obvious reason why the surrogate loss (essentially a cross-entropy) by itself would make the classes more compact in embedding space. However, compact clusters are an implicit assumption of K-means and Sinkhorn K-means, which makes the center loss essential for good validation performance.
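A center loss penalty can be sketched as follows. The exact normalization and center-update rule used by the authors are not given here, so this hypothetical `center_loss` assumes the centers are the current class means of the embedded points, and averages the squared distances over points (Wen et al. instead maintain running centers updated during training).

```python
import numpy as np

def center_loss(z, y, n_way):
    """Mean squared distance of each embedding to the mean of its own class."""
    # Assumption: centers are the class means of the current batch.
    mu = np.stack([z[y == k].mean(axis=0) for k in range(n_way)])
    return float(((z - mu[y]) ** 2).sum(axis=1).mean())
```

The total meta-training objective would then be the surrogate loss plus a weighted center loss; the default weight used in the experiments is 1.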

6 Experiments

We benchmark Centroid Networks against state-of-the-art methods on the few-shot clustering and unsupervised few-shot classification tasks. We train Centroid Networks by minimizing the surrogate loss with Sinkhorn conditionals combined with a center loss (both tricks improve the accuracies). Our method requires little to no tuning across datasets; to show this, we run all experiments with the following default hyperparameters: temperature = 1, Sinkhorn regularization = 1, center loss = 1, and Sinkhorn conditionals. The only exception is Omniglot-CCN, for which we use a different center loss weight. All reported accuracies for Centroid Networks are averaged over 1000 tasks (episodes), and 95% confidence intervals are given.

For unsupervised few-shot classification, we compare Centroid Networks with Prototypical Networks (Snell et al., 2017), a state-of-the-art supervised few-shot classification method, on the Omniglot and miniImageNet datasets. Data splits and architectures are the same as in Protonets, and can be found in Section A of the Appendix. We directly compare unsupervised and supervised accuracies, which is a fair comparison (Section 3.4). We exclude the 1-shot setting from our experiments because it is trivial in our case: for few-shot clustering, Centroid Networks would consistently assign exactly one point per cluster and obtain 100% accuracy, and for unsupervised few-shot classification, the centroids would be equal to the prototypes up to permutation, so Centroid Networks would reduce to Prototypical Networks at evaluation time. The results on Omniglot are given in Table 1, and the results on miniImageNet are given in Table 2. On Omniglot, Centroid Networks achieve nearly the same accuracy as Prototypical Networks despite using none of the labels of the support set. On miniImageNet, Centroid Networks still achieve a decent accuracy of 53.1%, compared with the reference supervised accuracy of 66.9%, despite not using any labels from the support set.

For few-shot clustering, we compare with Constrained Clustering Networks (Hsu et al., 2017, 2019), a recent state-of-the-art learning-to-cluster method, on their own task, which we denote Omniglot-CCN. (We choose to apply our method to their task rather than the opposite because their method is much slower and more complicated to run; by solving the same task, we can compare directly with the results from their paper.) Omniglot images are resized and split into 30 alphabets for training (background set) and 20 alphabets for evaluation. The Omniglot-CCN task consists of clustering each alphabet of the evaluation set individually, after training on the background set. This makes it harder than standard few-shot classification on Omniglot, because characters from the same alphabet are harder to separate, and because the number of ways varies from 20 to 47 characters per set. We run Centroid Networks with all default hyperparameters, except for the center loss weight. The results given in Table 3 show that Centroid Networks outperform all variants of CCN by a clear margin (86.8% vs. 83.3% for the best CCN variant). Furthermore, Centroid Networks are also simpler and about 100 times faster than CCN, because they only need to embed the data once, instead of iteratively minimizing a KCL/MCL criterion. However, we must recognize that the CCN approach is more flexible, because it can deal with a variable number of clusters which are not necessarily balanced.

Omniglot 5-way 5-shot Acc. 20-way 5-shot Acc.

Few-shot Clustering (clustering accuracy)

K-Means (raw images)
K-Means (Protonet features)
Centroid Networks (ours)

Unsupervised Few-shot Class. (unsupervised acc.)

Centroid Networks (ours)

Reference Oracle: Supervised Few-shot Class. (supervised acc.)
Prototypical Networks (Snell et al., 2017) * *
Table 1: Test accuracies on the Omniglot evaluation set. Clustering accuracies are given for few-shot clustering methods, unsupervised accuracies for unsupervised few-shot classification, and supervised accuracies for the reference oracle method (Prototypical Networks). Numbers marked with a star (*) are those reported by Snell et al. (2017) and match our implementation. All our accuracy results are averaged over 1000 test episodes with a fixed model and are reported with 95% confidence intervals.
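The captions report averages over 1000 test episodes with 95% confidence intervals. A common way to compute such an interval, assumed here because the text does not specify the estimator, is the normal approximation mean ± 1.96·s/√n, where s is the sample standard deviation of the per-episode accuracies and n the number of episodes. The helper name `mean_and_ci95` and the accuracies in the usage line are hypothetical.

```python
import math
import statistics


def mean_and_ci95(episode_accuracies):
    """Return (mean, half_width), where half_width = 1.96 * s / sqrt(n) and s is
    the sample standard deviation of the per-episode accuracies."""
    n = len(episode_accuracies)
    mean = statistics.mean(episode_accuracies)
    half_width = 1.96 * statistics.stdev(episode_accuracies) / math.sqrt(n)
    return mean, half_width


# Illustrative (made-up) per-episode accuracies for 1000 test episodes.
accs = [0.9, 1.0] * 500
mean, hw = mean_and_ci95(accs)
print(f"{100 * mean:.2f} +/- {100 * hw:.2f}%")
```

With a fixed model, the interval reflects only the variability across sampled episodes, so it shrinks as 1/√n when more episodes are evaluated.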
miniImageNet 5-way 5-shot Acc.

Few-shot Clustering (clustering accuracy)

K-Means (Protonet features)
Centroid Networks (ours, best setting)

Unsupervised Few-shot Classification (unsupervised acc.)

Centroid Networks (ours, best setting)

Reference Oracle: Supervised Few-shot Class. (supervised acc.)
Prototypical Networks (Snell et al., 2017)
Table 2: Test accuracies on the miniImageNet evaluation set. All our accuracy results are averaged over 1000 test episodes with a fixed model and are reported with 95% confidence intervals. The exception is Prototypical Networks, for which we report the results of a third-party implementation, since no miniImageNet code was provided with the original paper.
Omniglot (CCN setting) (20 to 47)-way 20-shot Acc.

Few-shot Clustering (clustering accuracy)
K-Means (raw features) *
CCN (KCL) (Hsu et al., 2017) *
CCN (MCL) (Hsu et al., 2019) *
Centroid Networks (ours, protonet arch.)
Centroid Networks (ours, CCN arch.)

Table 3: Test clustering accuracies on the Omniglot evaluation set, using the Constrained Clustering Network splits (Hsu et al., 2017). Numbers marked with a star (*) are those reported by Hsu et al. (2019). We report results with both the Protonet architecture and the CCN architecture of Hsu et al. (2017), which has more filters. The difference between the two architectures is not significant. All our accuracy results are averaged over 1000 test episodes with a fixed model and are reported with 95% confidence intervals.

Moreover, for few-shot clustering, we have also compared Centroid Networks with two baselines on Omniglot and miniImageNet (standard splits, see Appendix A). We have run K-means with K-means++ initialization directly on the raw images and found that it performs very poorly even on Omniglot (Table 1), which confirms the importance of learning an embedding function. We have also run K-means on pretrained Protonet features, which is a more interesting comparison, since at the highest level our method could be described as simply clustering Protonet embeddings (Tables 1 and 2). Centroid Networks still outperform K-means on these embeddings by a substantial margin on both Omniglot (99.6% vs. 83.5% for 5-way) and miniImageNet (62.6% vs. 48.7%), which confirms the importance of combining Sinkhorn conditionals, Sinkhorn K-means, and the center loss trick. Interestingly, on Omniglot 20-way 5-shot, the clustering accuracy of Centroid Networks is actually slightly higher than the supervised accuracy of Protonets (Table 1), despite using no labels from the support set. Although impressive, this result is not paradoxical; it only confirms that clustering accuracies cannot be directly compared with supervised accuracies (Section 3.4).
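The K-means baseline can be sketched in NumPy as K-means++ seeding followed by Lloyd iterations. This is a generic illustration, not the authors' code (they may have used a library implementation); the input `x` would hold either the flattened raw images or the pretrained Protonet embeddings of one episode, and the function names, `n_iter`, and the seed are illustrative choices.

```python
import numpy as np


def kmeans_pp_init(x, k, rng):
    """K-means++ seeding: first center uniformly at random, each further center
    sampled with probability proportional to its squared distance to the
    nearest center chosen so far."""
    n = x.shape[0]
    centers = [x[rng.integers(n)]]
    for _ in range(1, k):
        chosen = np.array(centers)
        d2 = ((x[:, None, :] - chosen[None, :, :]) ** 2).sum(axis=-1).min(axis=1)
        centers.append(x[rng.choice(n, p=d2 / d2.sum())])
    return np.array(centers)


def kmeans(x, k, n_iter=50, seed=0):
    rng = np.random.default_rng(seed)
    centers = kmeans_pp_init(x, k, rng)
    for _ in range(n_iter):
        # Assignment step: nearest center in squared Euclidean distance.
        d2 = ((x[:, None, :] - centers[None, :, :]) ** 2).sum(axis=-1)
        labels = d2.argmin(axis=1)
        # Update step: mean of assigned points (an empty cluster keeps its center).
        new_centers = np.array([
            x[labels == j].mean(axis=0) if np.any(labels == j) else centers[j]
            for j in range(k)
        ])
        if np.allclose(new_centers, centers):
            break
        centers = new_centers
    return labels, centers
```

Unlike Centroid Networks, this baseline places no constraint on cluster sizes, and nothing in it is adapted to the clustering step during training.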

7 Conclusion

We have proposed Centroid Networks for solving few-shot clustering and unsupervised few-shot classification. Centroid Networks can be easily trained by minimizing a simple supervised surrogate loss. Moreover, as in Prototypical Networks, the only learnable component is the embedding function. Predicting clusters is also extremely fast, as the Sinkhorn K-means algorithm takes virtually no time once the embeddings are computed. This is in stark contrast to recent state-of-the-art methods such as the Constrained Clustering Network approach, which require retraining a neural network on each clustering task. We have benchmarked Centroid Networks on Omniglot and miniImageNet and compared them against state-of-the-art methods from the learning-to-cluster and supervised few-shot classification literature. We have shown that Centroid Networks perform surprisingly well compared to Prototypical Networks, despite using none of the labels of the support set. The fact that we do not really need the support set labels to get decent accuracies on Omniglot and miniImageNet motivates recent work (Triantafillou et al., 2018) on developing harder meta-learning benchmarks.
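To illustrate why prediction is cheap once the embeddings are computed, here is a minimal NumPy sketch of a balanced, Sinkhorn-based K-means: each iteration computes an entropy-regularized transport plan between points and centroids with uniform marginals (so every cluster receives the same total mass), then moves each centroid to the plan-weighted mean of the points. This is our reading of the idea, not the authors' implementation: the squared Euclidean cost, the regularization `reg`, the iteration counts, and the deterministic farthest-point initialization are all illustrative assumptions, and the function names are hypothetical. The uniform cluster marginal is also what ties the method to balanced clusters, the limitation acknowledged in the CCN comparison.

```python
import numpy as np


def sinkhorn(cost, reg, n_iter=100):
    """Entropy-regularized optimal transport between N points and K clusters
    with uniform marginals (1/N per point, 1/K per cluster)."""
    n, k = cost.shape
    a = np.full(n, 1.0 / n)
    b = np.full(k, 1.0 / k)
    # Shifting the cost by a constant rescales the kernel but leaves the plan unchanged.
    kernel = np.exp(-(cost - cost.min()) / reg)
    v = np.ones(k)
    for _ in range(n_iter):
        u = a / (kernel @ v)    # rescale rows toward mass 1/N
        v = b / (kernel.T @ u)  # rescale columns to mass 1/K
    return u[:, None] * kernel * v[None, :]


def farthest_point_init(x, k):
    """Illustrative deterministic initialization: start from x[0], then repeatedly
    add the point farthest (squared Euclidean) from all centroids chosen so far."""
    idx = [0]
    d2 = ((x - x[0]) ** 2).sum(axis=1)
    for _ in range(1, k):
        idx.append(int(d2.argmax()))
        d2 = np.minimum(d2, ((x - x[idx[-1]]) ** 2).sum(axis=1))
    return x[idx].copy()


def sinkhorn_kmeans(x, k, reg=1.0, n_outer=20):
    centroids = farthest_point_init(x, k)
    for _ in range(n_outer):
        cost = ((x[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=-1)
        plan = sinkhorn(cost, reg)                            # soft, balanced assignment
        centroids = (plan.T @ x) / plan.sum(axis=0)[:, None]  # plan-weighted means
    # Hard labels: the cluster receiving most of each point's mass.
    return plan.argmax(axis=1), centroids
```

Every step is a dense matrix operation on an N × K cost matrix, so for episode-sized inputs the cost is negligible compared to running the embedding network.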


We thank Min Lin and Eugene Belilovsky for insightful discussions on few-shot classification. We thank Jose Gallego for insightful discussions on Sinkhorn K-Means. This research was partially supported by the NSERC Discovery Grant RGPIN2017-06936, a Google Focused Research Award and the Canada CIFAR AI Chair Program. We also thank Google for providing Google Cloud credits to help with the experiments.

Appendix A Data splits and architecture for the experiments

For the embedding network in all our experiments, we reuse exactly the same simple convolutional architecture as in Prototypical Networks (Snell et al., 2017), which consists of four stacked blocks (64-filter 2D convolution with 3×3 kernel and stride 1, BatchNorm, ReLU, and 2×2 max-pooling), the output of which is flattened. This results in a 64-dimensional embedding for Omniglot and a 1600-dimensional embedding for miniImageNet. For miniImageNet, we pretrain the embedding function using Prototypical Networks to solve 30-way problems instead of 5-way ones, a trick recommended by Snell et al. (2017). For the other settings, we train from scratch.
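The embedding sizes quoted above follow from simple shape arithmetic. The sketch assumes 64 filters per convolution and padding 1, so that each 3×3 convolution with stride 1 preserves the spatial size, while each 2×2 max-pooling with stride 2 halves it with floor rounding; the function name is hypothetical.

```python
def protonet_embedding_dim(side, n_blocks=4, n_filters=64):
    """Flattened output size of a four-block Protonet-style encoder on a
    side x side input (assumes padding 1 in every 3x3 convolution)."""
    for _ in range(n_blocks):
        # The 3x3 convolution keeps the size; 2x2 max-pooling halves it, rounding down.
        side = side // 2
    return side * side * n_filters


print(protonet_embedding_dim(28))  # 64   (Omniglot: 28 -> 14 -> 7 -> 3 -> 1)
print(protonet_embedding_dim(84))  # 1600 (miniImageNet: 84 -> 42 -> 21 -> 10 -> 5)
```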

Omniglot (Lake et al., 2011) consists of a total of 1623 classes of handwritten characters from 50 alphabets, with 20 examples per class. Images are grayscale with size 28×28. We follow the same protocol as in Prototypical Networks and use the "Vinyals" train/validation/test splits. We consider the 5-way 5-shot and 20-way 5-shot settings (15 query points per class).

miniImageNet (Vinyals et al., 2016) consists of 100 classes, each containing 600 color images of size 84×84. We follow the "Ravi" splits: 64 classes for training, 16 for validation, and 20 for testing. We consider the 5-way 5-shot setting (15 query points per class).
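Episodes in these settings can be sampled as sketched below: choose N classes, then K support and 15 query examples per class without replacement, and relabel the chosen classes 0..N-1 within the episode. The dictionary layout `class_to_examples` and the function name are hypothetical; in the few-shot clustering and unsupervised settings, the support labels would be withheld from the model at test time and used only to score its predictions.

```python
import random


def sample_episode(class_to_examples, n_way, n_shot, n_query=15, rng=random):
    """Sample an N-way K-shot episode as lists of (example, episode_label) pairs.
    class_to_examples maps a class id to a list of its examples."""
    classes = rng.sample(sorted(class_to_examples), n_way)
    support, query = [], []
    for label, cls in enumerate(classes):
        # Draw support and query examples of this class without replacement.
        examples = rng.sample(class_to_examples[cls], n_shot + n_query)
        support += [(x, label) for x in examples[:n_shot]]
        query += [(x, label) for x in examples[n_shot:]]
    return support, query


# Toy dataset with 20 examples per class, as in Omniglot.
toy = {c: list(range(c * 20, c * 20 + 20)) for c in range(30)}
support, query = sample_episode(toy, n_way=5, n_shot=5, rng=random.Random(0))
print(len(support), len(query))  # 25 75
```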