Building One-Shot Semi-supervised (BOSS) Learning up to Fully Supervised Performance

06/16/2020 · Leslie N. Smith et al. · U.S. Navy

Reaching the performance of fully supervised learning with unlabeled data and only labeling one sample per class might be ideal for deep learning applications. We demonstrate for the first time the potential for building one-shot semi-supervised (BOSS) learning on Cifar-10 and SVHN up to test accuracies that are comparable to fully supervised learning. Our method combines class prototype refining, class balancing, and self-training. A good prototype choice is essential, and we propose a practical technique for obtaining iconic examples. In addition, we demonstrate that class balancing methods substantially improve accuracy in semi-supervised learning, to levels that allow self-training to reach fully supervised performance. Rigorous empirical evaluations provide evidence that labeling large datasets is not necessary for training deep neural networks. We made our code available at https://github.com/lnsmith54/BOSS to facilitate replication and for use with future real-world applications.


1 Introduction

In recent years deep learning has achieved state-of-the-art performance on computer vision tasks such as image classification. However, a major barrier to the wider adoption of deep neural networks for new applications is that training state-of-the-art deep networks typically requires hundreds or thousands of labeled samples per class to perform at high levels of accuracy and to generalize well.

Unfortunately, manual labeling is labor intensive and might not be practical if labeling the data requires specialized expertise, such as in medical, defense, and scientific applications. In typical real-world scenarios for deep learning, one often has access to large amounts of unlabeled data but lacks the time or expertise to label the massive numbers of samples needed for training, validation, and testing. An ideal solution might be to achieve performance levels equivalent to fully supervised trained networks with only one manually labeled image per class.

In this paper we investigate the potential for building one-shot semi-supervised (BOSS) learning up to performance comparable to fully supervised training. To date, one-shot semi-supervised learning has been little studied and viewed as difficult. We build on the recent observation that one-shot semi-supervised learning is plagued by class imbalance problems Smith and Conovaloff (2020). In our context, class imbalance refers to a trained network that attains near 100% accuracy on a subset of classes but performs poorly on the others. We demonstrate that good prototypes are crucial for successful semi-supervised learning and propose a practical prototype replacement method for the poorly performing classes. Also, we make use of the state-of-the-art in semi-supervised learning methods (i.e., FixMatch Sohn et al. (2020)) in our experiments. To combat class imbalance, we tested several variations of methods found in the literature for class imbalance problems Johnson and Khoshgoftaar (2019), which refers to the situation where the number of training samples per class varies substantially. We are the first to demonstrate that these methods significantly boost the performance of one-shot semi-supervised learning. Combining these methods with self-training Rosenberg et al. (2005) makes it possible to attain performance on Cifar-10 and SVHN comparable to fully supervised trained deep networks.

Unfortunately, we also observed that one-shot semi-supervised learning is more sensitive to hyper-parameter tuning than fully supervised training, which makes training a delicate affair. While this sensitivity can be challenging in practice, it can also lead to new opportunities. For example, researchers often propose new network architectures, loss functions, and optimization functions that are tested in the fully supervised regime, where small performance gains are used to claim a new state-of-the-art. If these algorithms were instead tested in one-shot semi-supervised learning, more substantial differences in performance would better differentiate methods. Along these lines, we also advocate the use of one-shot semi-supervised learning with AutoML and neural architecture search (NAS) Elsken et al. (2018) to find optimal hyper-parameters and architectures.

Our contributions are:

  1. We rigorously demonstrate for the first time the potential for one-shot semi-supervised learning to reach test accuracies with Cifar-10 and SVHN that are comparable to fully supervised learning.

  2. We investigate the value of class balancing for one-shot semi-supervised learning. We introduce four class balancing methods that improve the performance of semi-supervised learning.

  3. We propose a practical method for finding iconic prototypes for each class and show that refining a few class prototypes can substantially improve performance.

2 Related Work

Semi-supervised learning:

Semi-supervised learning is a hybrid between supervised and unsupervised learning, which combines the benefits of both to better match the scenario of real-world problems. As with supervised learning, semi-supervised learning defines a task (i.e., classification) from labeled data but typically it uses many fewer labeled samples. Like unsupervised learning, semi-supervised learning leverages feature learning from unlabeled data to the greatest extent possible. Semi-supervised learning is a large and mature field and there are several surveys and books on semi-supervised learning methods Zhu (2005); Van Engelen and Hoos (2020); Chapelle et al. (2009); Zhu and Goldberg (2009) for the interested reader. In this Section we mention only the most relevant recent methods.

Recently there have been a series of papers on semi-supervised learning from Google Research, including MixMatch Berthelot et al. (2019b), ReMixMatch Berthelot et al. (2019a), and FixMatch Sohn et al. (2020). MixMatch combines consistency regularization with data augmentation Sajjadi et al. (2016), entropy minimization (i.e., sharpening) Grandvalet and Bengio (2005), and mixup Zhang et al. (2017). ReMixMatch improved on MixMatch by incorporating distribution alignment and augmentation anchors; augmentation anchors are similar to pseudo-labeling. FixMatch is the most recent and demonstrated state-of-the-art semi-supervised learning performance. In addition, the FixMatch paper includes a discussion of one-shot semi-supervised learning with Cifar-10.

The FixMatch algorithm Sohn et al. (2020) is primarily a combination of consistency regularization Sajjadi et al. (2016); Zhai et al. (2019) and pseudo-labeling Lee (2013). Consistency regularization utilizes unlabeled data by relying on the assumption that the model should output the same predictions for perturbed versions of an image as for the original image. Consistency regularization has recently become a popular technique in unsupervised, self-supervised, and semi-supervised learning Van Engelen and Hoos (2020); Zhai et al. (2019). Several researchers have observed that strong data augmentation should not be used when inferring pseudo-labels for the unlabeled data but should be employed for consistency regularization Sohn et al. (2020); Xie et al. (2019). Pseudo-labeling is based on the idea that one can use the model to obtain artificial labels for unlabeled data by retaining pseudo-labels for samples whose predicted probability is above a predefined threshold.

A recent survey of semi-supervised learning Van Engelen and Hoos (2020) provides a taxonomy of classification algorithms. One of the methods in semi-supervised learning is self-training iterations Triguero et al. (2015); Rosenberg et al. (2005), where a classifier is iteratively trained on labeled data plus high confidence pseudo-labeled data from previous iterations. In our experiments we found that self-training became reliable once the model’s performance was enhanced by prototype refining and class balancing.

Class imbalance: Smith and Conovaloff Smith and Conovaloff (2020) demonstrated that in one-shot semi-supervised learning there are large variations in class performance, with some classes achieving near 100% test accuracies while other classes are near 0%. That is, strong classes starve the weak classes, which is analogous to the class imbalance problem Johnson and Khoshgoftaar (2019). This observation suggests an opportunity to improve the overall performance by actively improving the performance of the weak classes.

We borrowed techniques from the literature on training with imbalanced data Johnson and Khoshgoftaar (2019); Wang and Yao (2012); Sun et al. (2007) (i.e., some classes having many more training samples than other classes) to experiment with several methods for improving the performance of the weak classes. Our experiments demonstrate that these methods substantially improve performance, even when there are the same number of labeled and unlabeled samples for each class. Methods for handling class imbalance can be grouped into two categories: data-level and algorithm-level methods. Data-level techniques Wang and Yao (2012) reduce the level of imbalance by undersampling the majority classes and oversampling the minority classes. Algorithm-level techniques Sun et al. (2007) are commonly implemented with smaller loss factor weights for the training samples belonging to the majority classes and larger weights for the training samples belonging to the minority classes. In our experiments we tested variations of both types of methods and a hybrid of the two.

Class-imbalanced semi-supervised learning Hyun et al. (2020) is related to our work, but Hyun et al. addressed the problem space where labeled training data are far more plentiful and the numbers of both labeled and unlabeled samples per class vary substantially. Hyun et al. propose a weighting scheme to under-weight the minority class contribution to the unlabeled loss function, while we instead reduce the weight of the majority classes in the unlabeled loss function, which is more consistent with the class imbalance literature. Li et al. Li et al. (2019) propose combining self-training with semi-supervised learning for few-shot classification but, unlike our method, their method employs a supervised few-shot method for pseudo-labeling.

Meta-learning: Our scenario superficially bears similarity to few-shot meta-learning Koch et al. (2015); Vinyals et al. (2016); Finn et al. (2017); Snell et al. (2017), which is a highly active area of research. The majority of the work in this area relies on a large labeled dataset with similar data statistics, but this can be an onerous requirement for new applications. While there are some recent efforts in unsupervised pretraining for few-shot meta-learning Hsu et al. (2018); Antoniou and Storkey (2019), our experiments with these methods showed that they could not perform well enough in one-shot learning to bootstrap our process. Specifically, unsupervised one-shot learning with only five classes obtained a test accuracy of about 50% on high confidence samples and the accuracy dropped sharply when increasing the number of classes.

3 BOSS Methodology

3.1 FixMatch

Since we build on FixMatch Sohn et al. (2020), we briefly describe the algorithm and adopt the formalism used in the original paper. For an N-class classification problem, let us define $\mathcal{X} = \{(x_b, p_b) : b \in (1, \dots, B)\}$ as a batch of $B$ labeled examples, where $x_b$ are the training examples and $p_b$ are their labels. We also define $\mathcal{U} = \{u_b : b \in (1, \dots, \mu B)\}$ as a batch of unlabeled examples, where $\mu$ is a hyper-parameter that determines the ratio of $\mathcal{U}$ to $\mathcal{X}$. Let $p_m(y \mid x)$ be the predicted class distribution produced by the model for input $x$. We denote the cross-entropy between two probability distributions $p$ and $q$ as $H(p, q)$.

The loss function for FixMatch consists of two terms: a supervised loss $\ell_s$ applied to the labeled data and an unsupervised loss $\ell_u$ for the unlabeled data. $\ell_s$ is the cross-entropy loss on weakly augmented labeled examples:

$$\ell_s = \frac{1}{B} \sum_{b=1}^{B} H\big(p_b, \, p_m(y \mid \alpha(x_b))\big) \qquad (1)$$

where $\alpha(x_b)$ represents weak data augmentation applied to sample $x_b$.

For the unsupervised loss, the algorithm computes the label based on weakly augmented versions of the image as $q_b = p_m(y \mid \alpha(u_b))$. It is essential that the label is computed on weakly augmented versions of the unlabeled training samples and not on strongly augmented versions. The pseudo-label is computed as $\hat{q}_b = \arg\max(q_b)$ and the unlabeled loss is given as:

$$\ell_u = \frac{1}{\mu B} \sum_{b=1}^{\mu B} \mathbb{1}\big(\max(q_b) \ge \tau\big) \, H\big(\hat{q}_b, \, p_m(y \mid \mathcal{A}(u_b))\big) \qquad (2)$$

where $\mathcal{A}(u_b)$ represents applying strong augmentation to sample $u_b$ and $\tau$ is a scalar confidence threshold. The total loss is given by $\ell_s + \lambda_u \ell_u$, where $\lambda_u$ is a scalar hyper-parameter. Additional details on the FixMatch algorithm are available in the paper Sohn et al. (2020).
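
To make the interplay of Equations 1 and 2 concrete, the following PyTorch sketch computes one FixMatch loss evaluation. It is a minimal illustration rather than the authors’ implementation; the names `model`, `x_ulb_weak`, `x_ulb_strong`, `tau`, and `lambda_u` are assumptions for the example, and the weak/strong augmentations are assumed to have been applied in the data pipeline.

```python
import torch
import torch.nn.functional as F

def fixmatch_loss(model, x_lb, y_lb, x_ulb_weak, x_ulb_strong, tau=0.95, lambda_u=1.0):
    """One FixMatch loss evaluation for a labeled batch and an unlabeled batch.

    `x_lb` holds weakly augmented labeled images; `x_ulb_weak` and `x_ulb_strong`
    hold weakly and strongly augmented views of the same unlabeled images.
    """
    # Supervised term: cross-entropy on weakly augmented labeled samples (Eq. 1).
    logits_lb = model(x_lb)
    loss_s = F.cross_entropy(logits_lb, y_lb)

    # Pseudo-labels come from the weakly augmented unlabeled views only.
    with torch.no_grad():
        probs = torch.softmax(model(x_ulb_weak), dim=-1)
        max_probs, pseudo_labels = probs.max(dim=-1)
        mask = (max_probs >= tau).float()  # keep only confident predictions

    # Unsupervised term: cross-entropy of strong views against pseudo-labels (Eq. 2).
    logits_ulb = model(x_ulb_strong)
    loss_u = (F.cross_entropy(logits_ulb, pseudo_labels, reduction="none") * mask).mean()

    return loss_s + lambda_u * loss_u
```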

3.2 Prototype refining

Previous work by Sohn et al. Sohn et al. (2020) on one-shot semi-supervised learning relied on the dataset labels to randomly choose an example for each class. The authors demonstrated that the choice of these samples significantly affected the performance of their algorithm. Specifically, they ordered the CIFAR-10 training data by how representative each sample was of its class, using fully supervised trained models, and found that using more prototypical examples achieved a median accuracy of 78% while the use of poorly representative samples failed to converge at all. The authors acknowledged that their method for finding prototypes was not practical. In contrast, we now present a practical approach for choosing an iconic prototype for each class.

In real-world scenarios, one’s data is initially all unlabeled but it is not overly burdensome for an expert to manually sift through some of their dataset to find one iconic example of each class. In choosing iconic images of each class, the labeler’s goal is to pick images that represent the class objects well, while minimizing the amount of background distractors in the image. In our own experiments with labeled datasets Cifar-10 and SVHN, we did not rely on the labels but reviewed a small fraction of the training data to manually choose class prototypes.

In addition, we propose a simple iterative technique for improving the choice of prototypes because good prototypes are crucial to good performance. After choosing prototypes, the next step is to make a training run and examine the final class accuracies. For any class with poor accuracy relative to the other classes, it is likely that a better prototype can be chosen. We recommend returning to the unlabeled dataset to find replacement prototypes for only the poorly performing classes. In our experiments we found doing this even once to be beneficial. In addition, our future plans include investigating potential performance improvements by preprocessing prototype images to minimize background distractors.
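
As a rough illustration of this refining loop (not part of the released code), the sketch below flags classes whose test accuracy lags the rest after a training run; the function name and the `margin` heuristic are our own assumptions.

```python
def classes_needing_new_prototypes(class_accuracies, margin=0.15):
    """Flag classes whose accuracy lags well behind the typical class.

    `class_accuracies` maps class name -> test accuracy in [0, 1];
    `margin` is an illustrative gap below the median that triggers replacement.
    """
    sorted_acc = sorted(class_accuracies.values())
    median = sorted_acc[len(sorted_acc) // 2]
    return [c for c, acc in class_accuracies.items() if acc < median - margin]

# Example: for prototype set 2 in Section 4.1, airplane and truck would be flagged,
# and replacing those two prototypes from the unlabeled pool produces set 6.
```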

One might argue that prototype refining is as much work as labeling several examples per class and using many training samples will make it easier to train the model. From only a practical perspective, labeling five or ten examples per class is not substantially more effort relative to labeling only one iconic example per class and prototype refining. While in practice one may want to start with more than one example for ease of training, there are scientific, educational, and algorithmic benefits to studying one-shot semi-supervised learning, which we discuss in our Broader Impact statement.

3.3 Class balancing

We believe a class imbalance problem is an important factor in training neural networks, not only in one-shot semi-supervised learning but also a factor for small to mid-sized datasets. A network with random weights usually outputs a single class label for every sample (i.e., randomly initialized networks do not generate random predictions). Hence, all networks start their training with elements of the class imbalance problem but the presence of large, balanced training data allows the network to overcome this problem. Since class imbalance is always present when training deep networks, class balancing methods might always be valuable, particularly when training on one-shot, few-shot, or small labeled datasets, and we leave further investigations of this for future work.

For class balancing, our algorithm first computes the number of pseudo-labels generated in each class and uses this measure as a surrogate for the model’s class imbalance. Specifically, as the algorithm computes the pseudo-labels for all of the unlabeled training samples, it counts the number that fall within each class, which we designate as $N_c$ for $c = 1, \dots, C$, where $C$ is the number of classes. We assume a similar number of unlabeled samples in each class so the number of pseudo-labels in each class should also be similar.

Our first class balancing method is based on oversampling minority classes. Our algorithm reduces the pseudo-labeling thresholds for minority classes to include more examples of the minority classes in the training. Formally, in pseudo-labeling the following unsupervised loss function is used for the unlabeled data in place of Equation 2:

$$\ell_u = \frac{1}{\mu B} \sum_{b=1}^{\mu B} \mathbb{1}\big(\max(q_b) \ge \tau_{\hat{q}_b}\big) \, H\big(\hat{q}_b, \, p_m(y \mid \mathcal{A}(u_b))\big) \qquad (3)$$

where $q_b = p_m(y \mid \alpha(u_b))$, $\hat{q}_b = \arg\max(q_b)$, and $\tau_c$ is the class-dependent threshold for inclusion in the unlabeled loss $\ell_u$. We define the class-dependent thresholds as:

$$\tau_c = \tau - \Delta\tau \left(1 - \frac{N_c}{\max_j N_j}\right) \qquad (4)$$

where $N_c$ is the number of pseudo-labels in class $c$ and $\Delta\tau$ is a scalar hyper-parameter ($\Delta\tau \ge 0$) guiding how much to lower the threshold for minority classes. Hence, the most frequent class will use a threshold of $\tau$ while minority classes will use lower thresholds, down to $\tau - \Delta\tau$.
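
The sketch below shows one way to implement this first balancing method: count the pseudo-labels per class and lower the thresholds linearly as in Equation 4. The function name is an illustrative assumption rather than the released implementation; the default `delta_tau=0.25` mirrors the value used in the Appendix experiments.

```python
import torch

def class_dependent_thresholds(pseudo_labels, num_classes, tau=0.95, delta_tau=0.25):
    """Lower the confidence threshold for classes that receive few pseudo-labels.

    `pseudo_labels` is a 1-D tensor of argmax predictions over the unlabeled data.
    The most frequent class keeps the base threshold `tau`; the scarcest classes
    drop toward `tau - delta_tau`, admitting more of their samples into the
    unlabeled loss (Eqs. 3 and 4).
    """
    counts = torch.bincount(pseudo_labels, minlength=num_classes).float()
    ratio = counts / counts.max().clamp(min=1.0)   # 1.0 for the majority class
    return tau - delta_tau * (1.0 - ratio)          # per-class thresholds tau_c
```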

The next two class balancing methods are variations on loss function class weightings. In the FixMatch algorithm, all unlabeled samples above the threshold are included in Equation 3 with the same weight. Instead, our second class balancing algorithm becomes:

$$\ell_u = \frac{\beta}{\mu B} \sum_{b=1}^{\mu B} \frac{1}{N_{\hat{q}_b}} \, \mathbb{1}\big(\max(q_b) \ge \tau\big) \, H\big(\hat{q}_b, \, p_m(y \mid \mathcal{A}(u_b))\big) \qquad (5)$$

where the loss terms are divided by $N_{\hat{q}_b}$ (the pseudo-label count of the sample’s class) and $\beta$ is a normalizing factor that makes $\ell_u$ the same magnitude as without this weighting scheme (this allows the unlabeled loss weighting $\lambda_u$ to remain the same).

Our third class balancing algorithm is identical to the previous method except it uses an alternate class count $\hat{N}_c$ in Equation 5. We define $\hat{N}_c$ using only the high confidence pseudo-labeled samples above the threshold. The intuition of this third method is that each of the classes should contribute equally to the loss (i.e., each sample’s loss is divided by the number of samples of that class included in $\ell_u$). In practice, this method’s weights might be an order of magnitude larger than the previous method’s weights, which might contribute to training instability, so we compare both methods in Section 4.2.
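
A minimal sketch of the class-weighted unlabeled loss of Equation 5 follows. It assumes precomputed pseudo-labels, confidences, and per-class counts, and it realizes the normalizing factor β as the ratio that keeps the total loss magnitude unchanged, which is one consistent reading of the description above rather than the exact released code.

```python
import torch
import torch.nn.functional as F

def balanced_unlabeled_loss(logits_strong, pseudo_labels, confidence, counts, tau=0.9):
    """Class-weighted unlabeled loss (balance methods 2 and 3, Eq. 5).

    Each retained sample's loss is divided by the pseudo-label count of its class,
    then rescaled by `beta` so the overall magnitude matches the unweighted loss.
    For method 2, `counts` are the pseudo-label counts over all unlabeled data;
    for method 3, they are recomputed from only the above-threshold samples.
    """
    mask = (confidence >= tau).float()
    per_class = counts.clamp(min=1.0)
    weights = 1.0 / per_class[pseudo_labels]   # divide by N_c of each sample's class
    beta = mask.sum() / (weights * mask).sum().clamp(min=1e-8)   # normalizing factor
    losses = F.cross_entropy(logits_strong, pseudo_labels, reduction="none")
    return (beta * weights * losses * mask).mean()
```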

3.4 Self-training iterations

Labeled and unlabeled data play different roles in semi-supervised learning. Here we propose self-training iterations where the pseudo-labels of the highest confidence unlabeled training samples are combined with labeled samples in a new iteration. Increasing the number of labeled samples per class improves performance and substantially reduces training instability and performance variability. Although some of these pseudo-labels might be wrong, we rely on the observation that the training of deep networks is robust to small amounts of labeling noise (i.e., labeling noise of less than 10% does not harm the trained network’s performance Algan and Ulusoy (2019)). Hence, we aimed to reach approximately 90% accuracy from semi-supervised learning with the class balancing methods.

Self-training is done in BOSS by adding to the testing stage a computation of the model predictions on all of the unlabeled training data. These are sorted from the highest prediction probabilities down and the dataset is saved. After the original training run, the labeled data can be combined with a number of the highest prediction samples from each class, and a subsequent self-training iteration can use the larger labeled dataset for retraining a new network. We experimented with labeling 5, 10, 20, and 40 of the top predictions per class and the results are reported in Section 4.3.
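
The selection step described above can be sketched as follows; the function name and the NumPy-based implementation are our own, but the logic (take the k most confident pseudo-labels of each class and add them to the labeled set) follows the procedure described in this section.

```python
import numpy as np

def select_top_k_per_class(pseudo_labels, confidences, num_classes, k=5):
    """Pick the k most confident pseudo-labeled samples from each class.

    Returns indices into the unlabeled dataset; these samples are combined with
    the original one-shot prototypes to form the labeled set for the next
    self-training iteration (e.g., k=5 gives a 60-sample labeled set on Cifar-10).
    """
    selected = []
    for c in range(num_classes):
        idx = np.where(pseudo_labels == c)[0]
        order = idx[np.argsort(-confidences[idx])]   # most confident first
        selected.extend(order[:k].tolist())
    return selected
```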

4 Experiments

In this Section we demonstrate that the BOSS algorithms can achieve comparable performance with fully supervised training on Cifar-10 Krizhevsky et al. (2009) and SVHN Netzer et al. (2011). We compare our results to FixMatch Sohn et al. (2020) (with appreciation, we acknowledge the use of the code kindly provided by the authors at https://github.com/google-research/fixmatch) and demonstrate the value of our approach. Our experiments use a Wide ResNet-28-2 Zagoruyko and Komodakis (2016) that matches the FixMatch reported results, and we used the same cosine learning rate schedule described by Sohn et al. Sohn et al. (2020). Our hyper-parameter values fell within a small range and the specifics are provided in the Appendix. For data augmentation, we used the default augmentation in FixMatch, though our experiments did show a small improvement when using RandAugment Cubuk et al. (2019) for strong data augmentation. Our runs with fully supervised learning of the Wide ResNet-28-2 model produced a test accuracy of 94.9% for Cifar-10 Krizhevsky et al. (2009) and a correspondingly high test accuracy for SVHN Netzer et al. (2011), which we use as our basis of comparison. We made our code available at https://github.com/lnsmith54/BOSS to facilitate replication and for use with future real-world applications.

Table 1: One-shot semi-supervised average (of 2 runs) class accuracies for Cifar-10 test data with the FixMatch model that was trained on sets of manually chosen prototypes for each class. Prototype set 6 was modified from set 2 and prototype set 7 was modified from set 4 (i.e., prototype refining).

4.1 Choosing prototypes and prototype refining

For our experiments with Cifar-10, we manually reviewed the first few hundred images and chose five sets of prototypes that we will refer to as class prototype sets 1 to 5. However, the practitioner need only create one set of class prototypes and can then perform prototype refining, as we describe below.

Table 1 presents the averaged (over two runs) test accuracies for each class, computed from FixMatch on the Cifar-10 test dataset for each of the prototype sets 1 to 5. This Table illustrates that a good choice of prototypes (i.e., set=3) can lead to good performance in all the classes. Table 1 also shows that for other sets the class accuracies can be quite high for some classes while low for other classes. Hence, the poor performance of some classes implies that the choice of prototypes for these classes in those sets can be improved. In prototype refining, one simply reviews the class accuracies to find which prototypes should be replaced.

We demonstrate prototype refining with two examples. The airplane and truck class accuracies in set 2 are poor so we replaced these two prototypes and name this set 6. In set 4, the cat and dog classes are performing poorly so we replaced these two prototypes and name this set 7. Table 1 shows the class accuracies for sets 6 and 7 and these results are better than the original sets. More importantly, the balancing of the accuracies across all the classes enables the use of this trained model to automatically generate labeled examples for self-training, as described in Section 3.4.

Table 2:

Main results. BOSS methods are compared using five sets of class prototypes (i.e., 1 prototype per class) for Cifar-10, plus two sets from prototype refining. The FixMatch column shows test accuracies (average and standard deviation of 4 runs) for the original FixMatch code on the prototype sets. The next four columns give the accuracy results for the class balance methods (see text for a description of the class balance methods). Results for the PyTorch reimplementation of FixMatch modified with the BOSS methods are shown in brackets [.]. The self-training iteration was performed with the top pseudo-labels from the run shown in bold and the results are in the final four columns.

4.2 Class balancing

In this Section we report the results from FixMatch and demonstrate substantial improvements with the class balancing methods in BOSS. Table 2 presents our main results, which illustrate the benefits from prototype refining, class balancing, and one self-training iteration. The rows in the table list the results for five sets of class prototypes (i.e., 1 prototype per class) for Cifar-10. Rows for sets 6 and 7 provide the results for prototype refining of the original sets 2 and 4, respectively. The FixMatch column shows results (i.e., average and standard deviation over four runs) for the original FixMatch code on the prototype sets. The numbers within brackets [.] are results from a PyTorch reimplementation of FixMatch, which we discuss below.

The next four columns present the BOSS results with class balancing methods. As described in Section 3.3, class balance method 1 represents oversampling of minority classes, balance methods 2 and 3 are two forms of class-based loss weightings, and balance method 4 is a hybrid that combines methods 1 and 3. The use of class balancing significantly improves on the original FixMatch results, with increases of up to 20 absolute percentage points. Generally, the hybrid class balance method 4 is best, except when instabilities hurt the performance. Crucially, the performance is generally in the 90% range with good performance across all the classes, which enables the self-training iteration.

Table 2 indicates that good class prototypes (i.e., sets 3, 6, and 7) result in test accuracies near 90% and low variance between runs. However, when some of the class prototypes are inferior, some of the training runs exhibit instabilities that cause lower averaged accuracies and higher variance. Other experiments in our Supplemental Materials demonstrate that in these cases, reducing the amount of class balancing reduces the instabilities (i.e., the quality of the class prototypes governs the hyper-parameter values).

PyTorch version: We have taken advantage of a PyTorch reimplementation of the original TensorFlow version of the FixMatch code (with appreciation, we acknowledge the use of the code provided at https://github.com/CoinCheung/fixmatch) to test our proposed BOSS methods in PyTorch. Table 2 reports the best test accuracies for the PyTorch version in the brackets [.].

It is clear to us that the researcher who reimplemented FixMatch in PyTorch took care to replicate FixMatch; in training with 4 labeled samples per class, this code obtained a Cifar-10 test accuracy comparable to the results reported in the paper. However, it is also clear from our experiments and Table 2 that there are substantial differences between the TensorFlow and PyTorch versions when comparing one-shot semi-supervised learning. A possible source of this difference might be the preprocessing step in the TensorFlow implementation. This preprocessing includes a sorting process of the unlabeled data that is not present in the PyTorch code. This preprocessing was not mentioned in Sohn et al. and could easily be deemed inconsequential, but it does seem to impact the trained network’s performance on one-shot semi-supervised learning. Hence, we observe that the sensitivity of one-shot semi-supervised learning reveals even minor differences that are invisible in fully supervised learning.

The PyTorch implementation also shows that the class balancing methods improve the test accuracy over FixMatch. In particular, class balance method 1 (i.e., oversampling) appears to improve the test accuracy more than the other methods.

4.3 Self-training iterations

The final four columns of Table 2 list the results of performing one self-training iteration. The self-training was initialized with the original single labeled sample per class, plus the most confident pseudo-labeled examples from the BOSS training run that is highlighted in bold. For example, the ‘+5’ column means that five pseudo-labeled examples per class were combined with the original labeled prototypes to make a set with a total of 60 labeled examples. These self-training results demonstrate that one-shot semi-supervised learning can reach performance comparable to fully supervised training (i.e., 94.9%), often by adding as few as 5 samples per class. However, we expect that in practice, self-training by adding more samples per class will prove more reliable.

Table 3: BOSS methods are compared using four sets of class prototypes (i.e., 1 prototype per class) for SVHN. The FixMatch column shows results for the original FixMatch code on the prototype sets. The next four columns give the accuracy results for the class balance methods. Results are an average of test accuracies over four runs. The self-training iteration was performed on the results from the class balancing shown in bold.

4.4 Svhn

SVHN is obtained from house numbers in Google Street View images and is used for recognizing digits (i.e., 0 – 9) in natural scene images. Visual review of the images shows that the training samples are of poor quality (i.e., blurry) and often contain distractors (i.e., multiple digits in an image). Because of the quality issue, we needed to review several hundred unlabeled training samples in order to find four class prototype sets.

Even though the SVHN training images are of poorer quality than the Cifar-10 training images, one-shot semi-supervised learning with FixMatch on sets of prototypes produced higher test accuracies than with Cifar-10. Table 3 presents equivalent results for the SVHN dataset as reported in Table 2 for Cifar-10. Since the results for FixMatch are all above 89%, we did not perform prototype refining on any of these sets. However, here too the class balancing methods increase the test accuracies above the FixMatch results. With these four class prototype sets, class balance method 1 produces the best results. The test accuracies from balance method 1 are approximately 1% lower than the fully supervised results. The improvements from self-training were small and the best results fell about 0.5% below fully supervised training. We believe the differences between Cifar-10 and SVHN are related to the nature of the datasets.

5 Conclusions

The BOSS methodology relies on simple concepts: choosing iconic training samples with minimal background distractors, employing class balancing techniques, and self-training with the highest confidence pseudo-labeled samples. Our experiments in Section 4 demonstrate the potential of training a network with only one sample per class and have confirmed the importance of class balancing methods. BOSS brings one-shot and few-shot semi-supervised learning closer to reality for applications with large amounts of unlabeled data.

Our work provides researchers with the following observations and insights:

  1. There is evidence that labeling a large number of samples might not be required for training deep neural networks.

  2. All networks have a class imbalance problem to some degree. Examining class accuracies relative to each other provides insights into the network’s training.

  3. Each training sample can affect the training. One-shot semi-supervised learning provides a mechanism to study the atomic impact of a single sample. This opens up the opportunity to investigate the factors in a sample that help or hurt training performance.

  4. The PyTorch reimplementation of FixMatch showed substantial differences from the TensorFlow version that were not apparent when training with four or more samples per class. This sensitivity of one-shot semi-supervised learning can be used with AutoML and Neural Architecture Search (NAS) to obtain optimal hyper-parameters and models. In addition, we recommend that researchers test their novel architectures, loss and optimization functions on one-shot semi-supervised learning to better differentiate their methods.

Broader Impact

It is widely accepted that large labeled datasets are an essential component of training deep neural networks, either directly for training or indirectly via transfer learning. To the best of our knowledge, this paper is the first to demonstrate performance comparable to fully supervised learning with one-shot semi-supervised learning. Eliminating the burden of labeling massive amounts of training data creates great potential for new neural network applications that attain high performance, which is especially important when labeling requires expertise. Hence, the societal impact will be to make deep learning applications even more widespread.

From a scientific perspective, one-shot semi-supervised learning provides important insights on the intricacies of training deep neural networks. The effect of changing just one training image can significantly impact the final performance. Unlike fully supervised learning that commonly deals with the training of large datasets, this method provides a technique to gain information about the impact of a single labeled sample in training. In addition, we anticipate that further investigation into the instability issues of one-shot semi-supervised learning will lead to new understandings of training neural networks.

Furthermore, the experience of training highly sensitive networks provides an educational experience on hyper-parameter tuning that carries over to easier training situations. In order to achieve convergence with one-shot semi-supervised learning, one must learn how to tune the hyper-parameters and architecture well. Similarly, we believe that utilizing one-shot semi-supervised learning with automatic methods such as AutoML and neural architecture search (NAS) will lead to better choices for hyper-parameters and architectures.

Limitations: While our work has taken valuable steps towards making one or few-shot semi-supervised learning possible for applications, a large gap still remains before this can be realized in practice, especially due to issues with stability during training and hyper-parameter sensitivity. The sensitivity of the results to choices of the hyper-parameters makes one-shot semi-supervised learning difficult to use in real-world applications. While there is a wide range of valuable applications (e.g., medical) that could benefit from semi-supervised learning, the testing of these applications is beyond the scope of this work.

While we attempted to provide a thorough investigation, there are a number of limitations in our work and several factors that we did not have sufficient time to explore. Our implementation was built on the state-of-the-art FixMatch algorithm, but the ideas presented here should carry over to other semi-supervised learning methods (this was not tested in our experiments). The model used in our experiments was a Wide ResNet-28-2; other architectures were not compared.

In addition, we made use of labeled test data to demonstrate the performance of BOSS. In practical settings, one has a large unlabeled dataset and one wishes to avoid burdensome manual labeling. However, the samples in the test dataset are less important than the choices for the class prototypes, so a small test dataset can be quickly created from the “discards” when searching for iconic prototypes. A small test dataset is useful for prototype refinement (i.e., deciding which class prototypes to replace) and it provides the practitioner with useful feedback on the system’s performance with a little additional effort. But even without any test data, one can utilize the pseudo-labeled class counts to decide which class prototypes should be replaced.

Furthermore, there are several assumptions that might not hold true in a practical setting. First of all, there is an implicit assumption that the unlabeled dataset is class balanced; that is, it contains the same number of samples of each class. In practical situations with large amounts of unlabeled data, this assumption is unlikely to be true. In cases where the number of unlabeled samples belonging to each class can be estimated, it is possible to adapt the class balancing methods. When the number of unlabeled samples belonging to each class is unknown, it is possible to create a small validation set in a similar manner as described above for creating a test set and utilize the validation set as a measure of class balance.

In addition, we also assume in our experiments that all of the unlabeled samples belong to one of the known classes. In practical settings, the unlabeled dataset might contain samples that don’t belong to any of the prototype classes. We did not test the situation where we use only a subset of the classes in the training datasets.

With appreciation, we acknowledge the use of the code kindly provided by the authors of FixMatch Sohn et al. (2020) at https://github.com/google-research/fixmatch and the author of a PyTorch reimplementation where the code was provided at https://github.com/CoinCheung/fixmatch. We also express our deep appreciation to Nicholas Carlini and David Berthelot for discussions of ReMixMatch and FixMatch that facilitated our use of their methods and codes.

This work was funded by the Office of Naval Research. The views and conclusions contained in this document are those of the authors and should not be interpreted as necessarily representing the official policies, either expressed or implied, of the US Navy.

References

  • [1] G. Algan and I. Ulusoy (2019) Image classification with deep learning in the presence of noisy labels: a survey. arXiv preprint arXiv:1912.05170. Cited by: §3.4.
  • [2] A. Antoniou and A. Storkey (2019) Assume, augment and learn: unsupervised few-shot meta-learning via random labels and data augmentation. arXiv preprint arXiv:1902.09884. Cited by: §2.
  • [3] D. Berthelot, N. Carlini, E. D. Cubuk, A. Kurakin, K. Sohn, H. Zhang, and C. Raffel (2019) ReMixMatch: semi-supervised learning with distribution alignment and augmentation anchoring. arXiv preprint arXiv:1911.09785. Cited by: §2.
  • [4] D. Berthelot, N. Carlini, I. Goodfellow, N. Papernot, A. Oliver, and C. A. Raffel (2019) Mixmatch: a holistic approach to semi-supervised learning. In Advances in Neural Information Processing Systems, pp. 5050–5060. Cited by: §2.
  • [5] O. Chapelle, B. Scholkopf, and A. Zien (2009) Semi-supervised learning (chapelle, o. et al., eds.; 2006)[book reviews]. IEEE Transactions on Neural Networks 20 (3), pp. 542–542. Cited by: §2.
  • [6] E. D. Cubuk, B. Zoph, J. Shlens, and Q. V. Le (2019) RandAugment: practical data augmentation with no separate search. arXiv preprint arXiv:1909.13719. Cited by: §4.
  • [7] T. Elsken, J. H. Metzen, and F. Hutter (2018) Neural architecture search: a survey. arXiv preprint arXiv:1808.05377. Cited by: §1.
  • [8] C. Finn, P. Abbeel, and S. Levine (2017) Model-agnostic meta-learning for fast adaptation of deep networks. In Proceedings of the 34th International Conference on Machine Learning, Volume 70, pp. 1126–1135. Cited by: §2.
  • [9] Y. Grandvalet and Y. Bengio (2005) Semi-supervised learning by entropy minimization. In Advances in neural information processing systems, pp. 529–536. Cited by: §2.
  • [10] K. Hsu, S. Levine, and C. Finn (2018) Unsupervised learning via meta-learning. arXiv preprint arXiv:1810.02334. Cited by: §2.
  • [11] M. Hyun, J. Jeong, and N. Kwak (2020) Class-imbalanced semi-supervised learning. arXiv preprint arXiv:2002.06815. Cited by: §2.
  • [12] J. M. Johnson and T. M. Khoshgoftaar (2019) Survey on deep learning with class imbalance. Journal of Big Data 6 (1), pp. 27. Cited by: §1, §2, §2.
  • [13] G. Koch, R. Zemel, and R. Salakhutdinov (2015) Siamese neural networks for one-shot image recognition. In ICML deep learning workshop, Vol. 2. Cited by: §2.
  • [14] A. Krizhevsky, G. Hinton, et al. (2009) Learning multiple layers of features from tiny images. Cited by: §4.
  • [15] D. Lee (2013) Pseudo-label: the simple and efficient semi-supervised learning method for deep neural networks. In Workshop on challenges in representation learning, ICML, Vol. 3, pp. 2. Cited by: §2.
  • [16] X. Li, Q. Sun, Y. Liu, Q. Zhou, S. Zheng, T. Chua, and B. Schiele (2019) Learning to self-train for semi-supervised few-shot classification. In Advances in Neural Information Processing Systems, pp. 10276–10286. Cited by: §2.
  • [17] Y. Netzer, T. Wang, A. Coates, A. Bissacco, B. Wu, and A. Y. Ng (2011) Reading digits in natural images with unsupervised feature learning. Cited by: §4.
  • [18] C. Rosenberg, M. Hebert, and H. Schneiderman (2005) Semi-supervised self-training of object detection models.. WACV/MOTION 2. Cited by: §1, §2.
  • [19] M. Sajjadi, M. Javanmardi, and T. Tasdizen (2016) Regularization with stochastic transformations and perturbations for deep semi-supervised learning. In Advances in neural information processing systems, pp. 1163–1171. Cited by: §2, §2.
  • [20] L. N. Smith and A. Conovaloff (2020) Empirical perspectives on one-shot semi-supervised learning. arXiv preprint arXiv:2004.04141. Cited by: §1, §2.
  • [21] J. Snell, K. Swersky, and R. Zemel (2017) Prototypical networks for few-shot learning. In Advances in neural information processing systems, pp. 4077–4087. Cited by: §2.
  • [22] K. Sohn, D. Berthelot, C. Li, Z. Zhang, N. Carlini, E. D. Cubuk, A. Kurakin, H. Zhang, and C. Raffel (2020) FixMatch: simplifying semi-supervised learning with consistency and confidence. arXiv preprint arXiv:2001.07685. Cited by: §A.1, §1, §2, §2, §3.1, §3.1, §3.2, §4, Broader Impact.
  • [23] Y. Sun, M. S. Kamel, A. K. Wong, and Y. Wang (2007) Cost-sensitive boosting for classification of imbalanced data. Pattern Recognition 40 (12), pp. 3358–3378. Cited by: §2.
  • [24] I. Triguero, S. García, and F. Herrera (2015) Self-labeled techniques for semi-supervised learning: taxonomy, software and empirical study. Knowledge and Information systems 42 (2), pp. 245–284. Cited by: §2.
  • [25] J. E. Van Engelen and H. H. Hoos (2020) A survey on semi-supervised learning. Machine Learning 109 (2), pp. 373–440. Cited by: §2, §2, §2.
  • [26] O. Vinyals, C. Blundell, T. Lillicrap, D. Wierstra, et al. (2016) Matching networks for one shot learning. In Advances in neural information processing systems, pp. 3630–3638. Cited by: §2.
  • [27] S. Wang and X. Yao (2012) Multiclass imbalance problems: analysis and potential solutions. IEEE Transactions on Systems, Man, and Cybernetics, Part B (Cybernetics) 42 (4), pp. 1119–1130. Cited by: §2.
  • [28] Q. Xie, E. Hovy, M. Luong, and Q. V. Le (2019) Self-training with noisy student improves ImageNet classification. arXiv preprint arXiv:1911.04252. Cited by: §2.
  • [29] S. Zagoruyko and N. Komodakis (2016) Wide residual networks. arXiv preprint arXiv:1605.07146. Cited by: §4.
  • [30] X. Zhai, A. Oliver, A. Kolesnikov, and L. Beyer (2019) S4l: self-supervised semi-supervised learning. In Proceedings of the IEEE international conference on computer vision, pp. 1476–1485. Cited by: §2.
  • [31] H. Zhang, M. Cisse, Y. N. Dauphin, and D. Lopez-Paz (2017) Mixup: beyond empirical risk minimization. arXiv preprint arXiv:1710.09412. Cited by: §2.
  • [32] X. Zhu and A. B. Goldberg (2009) Introduction to semi-supervised learning. Synthesis Lectures on Artificial Intelligence and Machine Learning 3 (1), pp. 1–130. Cited by: §2.
  • [33] X. J. Zhu (2005) Semi-supervised learning literature survey. Technical report University of Wisconsin-Madison Department of Computer Sciences. Cited by: §2.

Appendix A

A.1 Hyper-parameters

For FixMatch we used the default hyper-parameters that were specified in Sohn, et al. Sohn et al. (2020). However, in our initial experiments with the class balance methods, we found that these hyper-parameters performed poorly. Therefore, we used a different set of hyper-parameter values for FixMatch and for the BOSS methods.

Table 4 contains the hyper-parameter values used for the results reported in our paper. Additional hyper-parameter settings that were consistent over all the runs include setting kimgs = 32768 (i.e., the number of training images) and $\lambda_u = 1$ (i.e., the unlabeled loss multiplicative factor). Furthermore, we set the augment input parameter to ‘d.d.d’, which is the default data augmentation for the labeled and unlabeled data. Our early experiments with setting the augment input parameter to ‘d.d.rac’ produced only small improvements, so we subsequently used the default values. The balance column reflects the class balancing method used (balance = 0 corresponds to FixMatch, which does not use any class balancing method). The remaining columns specify the weight decay, learning rate, batch size, momentum, ratio of unlabeled to labeled data, confidence threshold, and change in the confidence threshold for minority classes. Details of these last three hyper-parameters are provided in the main text.

Specifically, we found that increasing the ratio of unlabeled to labeled data (from 7 to 9), the weight decay, and the learning rate (from 0.03 to 0.06) improved performance. We also found that decreasing the confidence threshold from 0.95 to 0.9 improved performance, but for class balancing methods 1 and 4 we left the confidence threshold at 0.95 because the class-based thresholds were already lowered by these class balancing methods. We also discovered that a smaller batch size improved performance and chose a batch size of 30, which is a multiple of the number of classes. Our experiments with momentum found a small improvement with values between 0.85 and 0.9 and settled on using 0.88 for our experiments.
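
For convenience, the settings stated in this section can be collected as follows. The dictionary keys are informal names (not the actual command-line flags of either code base), and the weight decay values are deliberately omitted because only the direction of that change is stated here.

```python
# Illustrative summary of the hyper-parameter changes described above for the
# BOSS (class-balanced) runs, relative to the FixMatch defaults.
BOSS_HYPERPARAMS = {
    "mu": 9,                  # ratio of unlabeled to labeled data (FixMatch default: 7)
    "lr": 0.06,               # initial learning rate (FixMatch default: 0.03)
    "batch_size": 30,         # a multiple of the number of classes
    "momentum": 0.88,         # best results observed between 0.85 and 0.9
    "lambda_u": 1,            # unlabeled loss multiplicative factor
    "kimgs": 32768,           # training length parameter, as stated above
    "tau": 0.9,               # confidence threshold for balance methods 2 and 3
    "tau_methods_1_4": 0.95,  # methods 1 and 4 keep the default threshold
}
```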

As mentioned above, we tried to use the same hyper-parameters for both FixMatch and the class balancing methods, but this proved to provide an unfair comparison to one or the other. Table 5 illustrates this. This Table provides the averaged test accuracies for class prototype set 2 for the default and another choice of weight decay (WD), learning rate (LR), batch size (BS), and the ratio of unlabeled to labeled data ($\mu$). The results for the BOSS methods improve significantly by tuning the hyper-parameters but the performance of FixMatch is reduced substantially. So we used the default set of hyper-parameters for FixMatch and another set of hyper-parameter values for the class balance methods.

Method          balance
FixMatch        0
Cifar training  1, 4
Cifar training  2, 3
Self-training   4
SVHN training   1, 4
SVHN training   2, 3
Self-training   0

Table 4: Class balancing method used at each of the various steps in the training; the remaining hyper-parameter values are described in the text above.
Table 5: Test accuracies for class prototype set 2 for two hyper-parameter settings. The hyper-parameters are weight decay (WD), learning rate (LR), batch size (BS), and the ratio of unlabeled to labeled data ($\mu$).

A.2 Implementation details

In this Section we describe the changes we made to the original FixMatch codes and provide guidance on how to replicate our experiments. This Section relies on the reader being familiar with the TensorFlow version at https://github.com/google-research/fixmatch and the PyTorch version located at https://github.com/CoinCheung/fixmatch. We provide a copy of our codes as part of our Supplemental Materials.

Modifications to the original TensorFlow version of the FixMatch code were localized. In the TensorFlow version, the primary changes were made to fixmatch.py. This includes the implementation of the four class balancing methods. In support of these methods, the code for computing the number of pseudo-labels in each class was implemented. Also, a few new input parameters were added to this file that are related to the class balancing methods. Specifically, we added the input parameter “balance” to specify the class balancing method (balance=0 acts the same as the original FixMatch code) and “delT” (i.e., $\Delta\tau$) as the amount that balance method 1 can reduce the threshold. Modifications were also made to cta/lib/train.py to compute test accuracies for each class, keep track of the best test accuracy, and output the sorted pseudo-labels for the unlabeled training data. In addition, changes to libml/data.py and libml/augment.py were required in order to accept the new prototype versions of the labeled datasets.

In addition to the code, the TensorFlow FixMatch version required several other steps that are supported by code in the scripts folder. Instructions for creating the necessary dataset files are located on the website at https://github.com/google-research/fixmatch. These instructions use programs in the scripts folder that needed to be modified in order to create the dataset files needed for the prototype sets and for self-training.

We named the prototype datasets with a ‘p’ at the end to distinguish them from the original datasets. That is, ‘cifar10’ became ‘cifar10p’ and ‘svhn’ became ‘svhnp’. Therefore, it was necessary to create scripts/cifar10_prototypes.py and scripts/svhn_prototypes.py to generate the labeled training data files. We note that to be consistent with the TensorFlow FixMatch, we used ‘seed’ as the input parameter to represent different prototype sets. It is also necessary to copy the unlabeled training and labeled training files from the cifar10/svhn file names to the cifar10p/svhnp file names and we provide shell scripts to do so.

Self-training is performed as a separate step from the first training run. The training run will have created three files containing the pseudo-labels for the unlabeled training data sorted from the most confident predictions down. The three files are the pseudo-labels, the confidences, and the true labels (used only for debug purposes). The programs scripts/cifar10_iteration.py and scripts/svhn_iteration.py are provided to combine the highest confidence pseudo-labeled examples with the labeled class prototypes and create the necessary files for the self-training run. We provide shell scripts as a template for how this is done. Once these files are created, the self-training iteration can be run.

Most of our experiments were run on a SuperMicro SuperServer with Tesla V100 GPUs. We discovered that it was important to run our experiments on only 1 GPU and all our runs using multiple GPUs performed poorly.

Modifications to the PyTorch version of the FixMatch code were simpler than for the TensorFlow code. However, the execution of this code ran almost three times longer, which greatly reduced the number of experiments we could run due to constraints on computational resources. The primary modifications for class balancing were added to label_guessor.py. Secondary modifications were made to the main program in train.py to add the class balancing input parameters and arguments for the call to label_guessor. In addition, cifar.py was modified to use the class prototypes instead of random examples. It was not necessary to create class prototype files as it was with the TensorFlow version. We did not have sufficient time to test self-training with the PyTorch version.

Figure 1: An example of training to a poor local minimum (blue) and training with instabilities (red). Both end with poor test accuracies but for different reasons.
Set  balance  Description           WD  LR    Δτ    λu    τ     Accuracy (%)
1    3        Instabilities         –   0.06  0     1     0.9   –
1    3        Decrease λu, WD, LR   –   0.04  0     0.5   0.9   –
2    4        Instabilities         –   0.06  0.25  1     0.95  –
2    4        Decrease Δτ, WD, LR   –   0.04  0.1   1     0.95  –
4    1        Local min             –   0.06  0.25  1     0.9   –
4    1        Increase Δτ, τ        –   0.06  0.3   1     0.95  –
4    2        Local min             –   0.06  0     1     0.9   –
4    2        Increase λu           –   0.06  0     2     0.9   –
4    3        Local min             –   0.06  0     1     0.9   –
4    3        Increase λu           –   0.06  0     2     0.9   –
5    1        Instabilities         –   0.06  0.25  1     0.95  –
5    1        Decrease Δτ           –   0.06  0.1   1     0.95  –
5    2        Instabilities         –   0.06  0     1     0.9   –
5    2        Decrease λu           –   0.06  0     0.75  0.9   –
5    3        Instabilities         –   0.06  0     1     0.9   –
5    3        Decrease WD, LR       –   0.04  0     1     0.9   –
Table 6: Illustration of the sensitivity to the hyper-parameters WD, LR, Δτ, λu, and τ. See the text for guidance on how to tune these hyper-parameters for situations with inferior performance due to instabilities or local minimums.

A.3 Discussion of training instabilities, poor local minimum, and hyper-parameter sensitivity

In our experiments we observed sensitivity of one-shot semi-supervised learning performance to the choices of hyper-parameters and class prototype sets. That is, we observed that good choices for the prototypes and prototype refining significantly reduced the instabilities and the variability of the results (i.e., few instabilities were encountered for Cifar-10 prototype sets 3, 6, and 7, so the final accuracies were higher and the standard deviations of the results were lower). In sets where the performance was inferior, there was always at least one class that performed poorly. However, we also found that the hyper-parameter values made a significant difference.

We investigated the cases of poor performance and discovered that there were two different situations. Figure 1 provides examples of test accuracies during the training for both situations. The blue curve shows the test accuracy of a training run in which the network reaches a final test accuracy of 77% for the case of class prototype set 4 and balance method 3, with the hyper-parameters described in Section A.1 (on the other hand, another run with the same hyper-parameters produced an accuracy of 93%). We hypothesize that in this situation the network can get stuck in a poor local minimum. The red curve in Figure 1 is an example of the other case; here the training run reaches a final test accuracy of 65% (i.e., for class prototype set 5 and balance method 3). Clearly the behavior during training is different in this case because the training is dominated by instabilities (i.e., the model suddenly diverges during training).

We found that these two situations are two sides of the same problem and it is important when tuning the hyper-parameters to identify which one is occurring. Specifically, for the results reported in Table 2 of our main paper, the inferior results for class prototype set 4 were due to a poor local minimum while the inferior results for sets 1, 2, and 5 were due to instabilities.

Our experiments imply that too much class balancing can cause the training instabilities. We hypothesize that the model struggles to classify the unlabeled examples with lower quality class prototypes, but the class balancing methods force the pseudo-labeling to mislabel samples in order to have the appearance of class balance. In these cases, it is better to reduce the amount of class balancing by using a smaller value for $\Delta\tau$ for class balance methods 1 and 4, and using a smaller value for $\lambda_u$ for class balance methods 2 and 3. In addition, we observed that decreasing the weight decay (WD) and the learning rate (LR) improves performance when there were instabilities.

On the other hand, if the inferior performance is due to a poor local minimum, increasing the amount of class balancing improves the performance. In this case, the accuracy increases and the standard deviation decreases by using a larger value for $\Delta\tau$ for class balance methods 1 and 4, and using a larger value for $\lambda_u$ for class balance methods 2 and 3. In addition, we observed that increasing the weight decay (WD) and the learning rate (LR) improves performance. We also observed that it helps to increase the confidence threshold $\tau$ if there are instabilities and to decrease it in the poor local minimum situation. Table 6 provides examples of these recommendations.
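
These recommendations can be summarized as a small decision helper; it is purely illustrative and simply restates the guidance above for the two failure modes.

```python
def tuning_advice(failure_mode, balance_method):
    """Restate the hyper-parameter guidance above for the two failure modes.

    `failure_mode` is "instabilities" (sudden divergence during training) or
    "local_minimum" (smooth training that converges to a poor accuracy).
    """
    # Methods 1 and 4 are tuned through delta_tau; methods 2 and 3 through lambda_u.
    knob = "delta_tau" if balance_method in (1, 4) else "lambda_u"
    if failure_mode == "instabilities":
        return [f"decrease {knob}", "decrease weight decay and learning rate"]
    if failure_mode == "local_minimum":
        return [f"increase {knob}", "increase weight decay and learning rate"]
    raise ValueError("failure_mode must be 'instabilities' or 'local_minimum'")
```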

Table 6 demonstrates how to improve the results presented in our main paper, where for consistency we used the same hyper-parameter values for all of the BOSS runs. Now we show that tuning can improve the test accuracies above the values reported in the paper.

Table 6 contains results of hyper-parameter fine-tuning for the cases where we reported test accuracies below 85%. We list the class prototype set (set), the BOSS class balancing method (balance), weight decay (WD), initial learning rate (LR), the change in the confidence threshold for minority classes ($\Delta\tau$), the unlabeled loss multiplicative factor ($\lambda_u$), the confidence threshold ($\tau$), and the final test accuracy in percent. Furthermore, we provide a short description that indicates whether the training curve displays instabilities (i.e., the red curve in Figure 1) or a poor local minimum (i.e., the blue curve), or the description points out the hyper-parameters that were tuned to improve the performance.

For example, the first row in the Table shows the results for set 1 using class balance method 3. Examination of the output from this run showed a curve resembling the red curve in Figure 1, implying the problem is one of instabilities. This calls for a decrease in $\lambda_u$, which improved the accuracy and reduced the standard deviation.

The other examples in Table 6 show improved results for both the problem of instability and for poor local minimums. The examples include modifying $\Delta\tau$, $\lambda_u$, the weight decay, and the learning rate. In most cases the final accuracies improved substantially with small changes in the hyper-parameter values, which demonstrates the sensitivity of one-shot semi-supervised learning to hyper-parameter values.