RealMix: Towards Realistic Semi-Supervised Deep Learning Algorithms

12/18/2019 ∙ by Varun Nair, et al.

Semi-Supervised Learning (SSL) algorithms have shown great potential in training regimes where labeled data is scarce but unlabeled data is plentiful. However, our experiments illustrate several shortcomings of prior SSL algorithms, in particular poor performance when the unlabeled and labeled data distributions differ. To address these shortcomings, we develop RealMix, which achieves state-of-the-art results on standard benchmark datasets across different labeled and unlabeled set sizes while overcoming the aforementioned challenges. Notably, RealMix achieves an error rate of 9.79% on CIFAR10 with 250 labels and is the only SSL method tested that surpasses baseline performance when there is significant mismatch between the labeled and unlabeled data distributions. RealMix demonstrates how SSL can be used in real-world situations with limited access to both data and compute, and it guides further research in SSL with practical applicability in mind.

1 Introduction

Recent progress in deep learning has largely been driven by the development of specialized hardware and the abundance of large, labeled datasets. While these techniques work well when data is widely and cheaply available, they are impractical for real-world problems where collecting data is both time-consuming and expensive. Typical examples include diagnosis from medical imaging and robotic perception.

Figure 1: A high-level overview of RealMix, a novel semi-supervised learning technique that improves classification performance when there is a significant shift between the distributions of the unlabeled and the labeled data.

To address challenges in these domains, Semi-Supervised Learning (SSL) algorithms have emerged as a useful tool [3]. SSL algorithms seek to learn the underlying structure of data by utilizing large amounts of unlabeled data, which is often more readily available than labeled data. Recent work in SSL [2, 9, 14, 15] has built on a number of assumptions. First, model outputs on unlabeled data should be invariant to small perturbations (i.e. consistency training). Second, encouraging model outputs to be more confident will steer decision boundaries away from high-density regions (i.e. entropy minimization [6]). Finally, the training data distribution can be extended using linear interpolations of data points (i.e. MixUp [18]).

SSL algorithms are typically evaluated by taking a standard benchmarking dataset (e.g. CIFAR10 [7], SVHN [10]) and discarding a significant fraction of the labels. This results in a small labeled dataset and a larger unlabeled dataset that both come from the same distribution. The current state-of-the-art SSL technique MixMatch [2] is able to recover over 92% of the test accuracy on CIFAR10 using 200 times fewer labels than the supervised baseline. These advances prompt the following question: Can SSL algorithms sufficiently alleviate the need for labeled data in real-world settings?

Oliver et al. [11] argue that the current approach to evaluating SSL algorithms is inadequate and raise several questions about SSL's real-world applicability. In particular, they find that the performance of SSL techniques suffers when there is a significant mismatch between the unlabeled and labeled data distributions, and that transfer learning with labeled data alone can often outperform SSL. We reevaluate these findings in sections 4.2.2 and 4.2.3, showing that they no longer hold.

These problems have, until now, been major obstacles to the adoption of SSL techniques in realistic setups. We define a realistic setup for SSL as one in which a practitioner compares SSL performance with transfer learning using limited labeled samples (given its success in pre-training classifiers [1]), and in which unlabeled samples are not guaranteed to come from the labeled data distribution. Our goal is to develop a deep SSL algorithm that unites successful practices in SSL and is viable in realistic setups.

We present RealMix, an SSL algorithm depicted in fig. 1 that unites ("mixes") the most successful approaches in SSL to set state-of-the-art results on benchmark datasets while surpassing baseline performance when there is significant mismatch between the unlabeled and labeled datasets. Our contributions can be summarized as follows:

  • We perform experiments to show that RealMix sets state-of-the-art results on CIFAR10 and SVHN, achieving an error rate of 9.79% on CIFAR10 using 250 labels.

  • We experimentally demonstrate that RealMix is applicable in real-world settings by showing that when the unlabeled distribution is significantly different from the labeled distribution, we can still improve on the supervised baseline performance. Notably, RealMix is the only SSL approach tested that is able to surpass baseline performance when there is significant or complete mismatch in the labeled and unlabeled distributions.

  • We demonstrate that RealMix (in addition to MixMatch [2]) surpasses transfer learning, and that transfer learning is complementary to SSL. We show this experimentally by pre-training a classifier on ILSVRC-2012 [13] and applying RealMix to further reduce the error on CIFAR10 with 250 labels to just 8.48%.

  • We also perform an ablation study on RealMix to identify the components that lead to its success in realistic scenarios.

  • We provide our implementation source code as a publicly available repository (https://github.com/uizard-technologies/realmix) to foster future research.

We continue our discussion of RealMix in the next section by detailing successful approaches in SSL, how RealMix unites these approaches, and what new elements are introduced by RealMix to make it work in realistic scenarios. In section 4, we carry out several experiments with RealMix that lead to state-of-the-art results on benchmark image classification datasets and demonstrate its effectiveness when unlabeled and labeled data distributions mismatch.

1:  Require:  : deep neural network with trainable parameters
2:  Require:  : set of labeled data points
3:  Require:  : set of unlabeled data points
4:  Require:  : stochastic data augmentation function for unlabeled data
5:  Require:  : stochastic data augmentation function for consistency training
6:  Require:  : MixUp function and Beta distribution parameter
7:  Require:  : Training signal annealing function for supervised loss and annealing schedule.
8:  Require:  : Out-of-distribution masking function for unsupervised loss and masking parameter.
9:  
10:  for  in  do
11:     for  in  do
12:         
13:         
14:     end for
15:     
16:     
17:     
18:     
19:     
20:     if  then
21:         
22:     end if
23:     
24:     
25:     
26:     
27:     perform gradient descent update on using
28:  end for
29:  return  
Algorithm 1 Pseudocode for RealMix Algorithm

2 Related Work

While SSL techniques have a rich history [3, 4, 19], we focus on the methods that recent deep variants use to achieve state-of-the-art results, and on literature that considers SSL in realistic setups.

2.1 Consistency Training and Data Augmentation

Chapelle et al. [3, 4] describe the cluster assumption, under which data samples that belong to the same cluster are likely to belong to the same class. Unlabeled data points can then be used to better define the boundaries of these clusters, where the class of each cluster is defined by the labeled data points within it. This assumption is also equivalent to the low-density assumption, under which decision boundaries should lie in low-density regions. Consistency training (also called consistency regularization) combines these assumptions into a regularization task: given an unlabeled data point $u$, a classifier $f$, and a small perturbation $\delta$, we require $f(u) = f(u + \delta)$. In other words, a classifier should be invariant to small perturbations applied to the input, which is typically enforced by an additional loss term.
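To make the invariance requirement concrete, the sketch below penalizes the squared difference between the predicted class distributions for an input and for a perturbed copy of it. It is a minimal illustration, not RealMix's code: the linear classifier `f`, the Gaussian-noise perturbations, and the helper names are hypothetical.

```python
import numpy as np

def softmax(logits):
    # Numerically stable softmax over the class axis.
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def consistency_loss(f, u, perturb, rng):
    # Mean squared difference between the class distributions predicted
    # for the clean inputs u and for a perturbed copy of u.
    p_clean = softmax(f(u))
    p_perturbed = softmax(f(perturb(u, rng)))
    return float(np.mean((p_clean - p_perturbed) ** 2))

# Hypothetical toy setup: a linear classifier and additive Gaussian noise.
rng = np.random.default_rng(0)
W = rng.normal(size=(8, 3))
f = lambda x: x @ W
u = rng.normal(size=(16, 8))

small = consistency_loss(f, u, lambda x, r: x + 0.01 * r.normal(size=x.shape), rng)
large = consistency_loss(f, u, lambda x, r: x + 5.0 * r.normal(size=x.shape), rng)
```

A small perturbation yields a penalty close to zero, while a large one yields a much larger penalty; minimizing this term pushes the classifier toward the invariance described above.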

The choice of perturbation induced on an unlabeled sample has varied across SSL techniques. The Γ-Model [12], Π-Model [8], and Mean Teacher [14] perturb unlabeled samples using Gaussian noise and simple data augmentations (e.g. random translation and horizontal flips), while VAT [9] applies noise that adversarially affects classifier outputs. UDA [16] applies a more diverse range of augmentations, and ICT [15] and MixMatch [2] both use MixUp [18] to train an SSL classifier to output consistent predictions on linear interpolations of data points.

RealMix performs consistency training by applying MixUp, horizontal flips, and random translation to labeled and unlabeled samples (detailed in section 3.3). We also extend our unlabeled sample distribution by creating several augmented copies of the unlabeled data (augmented using CutOut [5]).

2.2 Entropy Minimization

Entropy Minimization (EM) [6] has been applied in SSL to encourage high-confidence classifier outputs. This approach is also inspired by the low-density assumption, since a classifier whose decision boundary passes through high-density regions would make low-confidence predictions on many samples. VAT [9] incorporates EM as a loss term to further improve results, and MixMatch [2] and UDA [16] apply EM by sharpening the targets of unlabeled samples. We also apply EM through a sharpening function (described in section 3.2), as we find that it works well experimentally.

2.3 SSL in Realistic Contexts

Oliver et al. [11] described a number of pitfalls of current SSL algorithms and provided recommendations to practitioners for when SSL may be appropriate. We do not investigate all of their findings, but instead focus on those most pertinent to our work. Specifically, these include that SSL is most likely applicable if:

  • Transfer learning from similar domains using labeled datasets is not feasible.

  • The labeled and unlabeled data samples are drawn from the same distribution.

With the two above-mentioned points in mind, we find that:

  • RealMix outperforms transfer learning and fine-tuning even when transfer learning from similar domains is feasible, including when the target and transfer domains share classes. We show this experimentally in section 4.2.3.

  • RealMix is capable of surpassing baseline performance even when upwards of 75% of the unlabeled data comes from a different distribution than the labeled data. We accomplish this using out-of-distribution masking, which prevents our classifier from learning on examples that are out-of-distribution. This is detailed further in section 3.4.

1:  Require:  : entropy minimization function
2:  Require:  : batch of unlabeled samples
3:  
4:  
5:  
6:  
7:  
8:  
9:  return  
Algorithm 2 Pseudocode for generating targets

3 RealMix

As discussed in sections 1 and 2, RealMix unites the most successful approaches in SSL and adapts them to work in realistic contexts. An overview of RealMix is presented in fig. 1 and algorithm 1.

Formally, given labeled samples $\mathcal{X}$, unlabeled samples $\mathcal{U}$, MixUp Beta distribution parameter $\alpha$, the out-of-distribution masking parameter, and consistency training (unlabeled loss) weight $\lambda_{\mathcal{U}}$, we can obtain a classifier $f_\theta$ that minimizes eq. 2:

(1)
$$\mathcal{L} = \mathcal{L}_{\mathcal{X}} + \lambda_{\mathcal{U}} \, \mathcal{L}_{\mathcal{U}} \qquad (2)$$

where $\mathcal{L}_{\mathcal{X}}$ is the standard cross-entropy loss on labeled samples, $\lambda_{\mathcal{U}}$ is the consistency training (unlabeled loss) weight, and $\mathcal{L}_{\mathcal{U}}$ is computed using the mean squared error (MSE) with out-of-distribution masking (see section 3.4) on the targets of unlabeled samples. The generation of unlabeled targets is presented in algorithm 2, and we discuss the hyperparameters in the following subsections.
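The combined objective can be sketched in a few lines. This is a minimal illustration under stated assumptions, not the authors' implementation: the helper names are hypothetical, the mask is taken as given (a boolean array marking unlabeled samples that survive out-of-distribution masking), and averaging the MSE over kept samples only is our own normalization choice.

```python
import numpy as np

def cross_entropy(targets, probs, eps=1e-12):
    # Supervised term L_X: mean cross-entropy between (one-hot or soft) targets and predictions.
    return float(-np.mean(np.sum(targets * np.log(probs + eps), axis=1)))

def masked_mse(targets, probs, mask):
    # Unsupervised term L_U: per-sample squared error, averaged over samples kept by the mask.
    per_sample = np.mean((targets - probs) ** 2, axis=1)
    if not mask.any():
        return 0.0  # every unlabeled sample was masked out
    return float(per_sample[mask].mean())

def total_loss(x_targets, x_probs, u_targets, u_probs, mask, lambda_u):
    # Eq. 2: L = L_X + lambda_U * L_U
    return cross_entropy(x_targets, x_probs) + lambda_u * masked_mse(u_targets, u_probs, mask)

# Toy values: uniform predictions on two labeled samples, one wrong and one
# correct prediction on two unlabeled samples.
x_targets = np.eye(2)
x_probs = np.full((2, 2), 0.5)
u_targets = np.array([[1.0, 0.0], [0.0, 1.0]])
u_probs = np.array([[0.0, 1.0], [0.0, 1.0]])
```

Masking out the first unlabeled sample removes its error from $\mathcal{L}_{\mathcal{U}}$ entirely, which is how low-confidence (likely out-of-distribution) samples stop contributing gradients.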

3.1 Data Augmentation

Following UDA [16], we first extend our unlabeled set by applying 50 rounds of augmentation to it using the unlabeled-data augmentation function of algorithm 1, which can include cropping, flipping, or stronger augmentations such as CutOut [5]. By using several augmented copies of the unlabeled data, we expose our classifier to a wide range of perturbations that provide additional inductive biases about the data distribution.
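A rough sketch of extending an unlabeled set with augmented copies follows. Every specific here is an assumption for illustration, not the paper's configuration: images are HWC float arrays, each copy is a random horizontal flip plus a translation of up to 2 pixels (reflect padding) followed by one 8x8 CutOut patch, and the function names are hypothetical.

```python
import numpy as np

def random_translate_flip(img, rng, max_shift=2):
    # Horizontal flip with probability 0.5, then a random translation via reflect padding.
    if rng.random() < 0.5:
        img = img[:, ::-1]
    h, w = img.shape[:2]
    padded = np.pad(img, ((max_shift, max_shift), (max_shift, max_shift), (0, 0)), mode="reflect")
    dy, dx = rng.integers(0, 2 * max_shift + 1, size=2)
    return padded[dy:dy + h, dx:dx + w]

def cutout(img, rng, size=8):
    # Zero out one square patch centred at a uniformly sampled pixel (clipped at the border).
    out = img.copy()
    h, w = img.shape[:2]
    cy, cx = rng.integers(0, h), rng.integers(0, w)
    y0, y1 = max(cy - size // 2, 0), min(cy + size // 2, h)
    x0, x1 = max(cx - size // 2, 0), min(cx + size // 2, w)
    out[y0:y1, x0:x1] = 0.0
    return out

def extend_unlabeled(images, n_copies, rng):
    # Stack n_copies independently augmented versions of every image: (n_copies * N, H, W, C).
    copies = [cutout(random_translate_flip(x, rng), rng) for _ in range(n_copies) for x in images]
    return np.stack(copies)

rng = np.random.default_rng(0)
unlabeled = rng.random((4, 32, 32, 3))           # hypothetical toy batch
extended = extend_unlabeled(unlabeled, 50, rng)  # 50 rounds, as in the text
```

Precomputing the copies once trades memory for a fixed, diverse pool of perturbations; applying augmentations on the fly per step would be the usual alternative.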

As part of consistency training, we compute targets for each unlabeled batch by averaging the classifier's predicted distributions over two additional augmentations created by a separate augmentation function for consistency training (as shown in algorithm 2). We settled on two augmentations because additional augmentations significantly increased training time without significantly improving results. Note that the unlabeled-data augmentation function produces many copies of the unlabeled data, whereas the consistency-training augmentation function produces just one copy at a time for use in generating targets.
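The averaging step of target generation can be sketched as follows, before any sharpening is applied. The linear `model` and the noise-based `augment` are hypothetical stand-ins for the classifier and the consistency-training augmentation function.

```python
import numpy as np

def softmax(logits):
    # Numerically stable softmax over the class axis.
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def average_predictions(model, u_batch, augment, rng, n_aug=2):
    # Average the predicted class distribution over n_aug independent augmentations.
    preds = [softmax(model(augment(u_batch, rng))) for _ in range(n_aug)]
    return np.mean(preds, axis=0)  # shape (batch, classes); rows still sum to 1

rng = np.random.default_rng(1)
W = rng.normal(size=(5, 4))
model = lambda x: x @ W                                      # hypothetical classifier
augment = lambda x, r: x + 0.1 * r.normal(size=x.shape)      # hypothetical augmentation
q_bar = average_predictions(model, rng.normal(size=(6, 5)), augment, rng)
```

Because each row is an average of probability distributions, the result is itself a valid distribution, which the sharpening step then turns into a lower-entropy target.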

3.2 Entropy Minimization

MixMatch [2] and UDA [16] both implement entropy minimization through a sharpening function, which we also find helpful. By applying this function (line 8 of algorithm 2) to the unlabeled targets, we encourage our classifier to produce low-entropy predictions on unlabeled data. That is, for each class $i$ out of $C$ classes:

$$\operatorname{Sharpen}(\bar{q}, T)_i = \frac{\bar{q}_i^{1/T}}{\sum_{j=1}^{C} \bar{q}_j^{1/T}} \qquad (3)$$

where $\bar{q}$ is the average predicted class distribution and $T$ is the temperature of the sharpened distribution. Intuitively, the distribution approaches a one-hot distribution as $T$ goes to 0. We find a single value of $T$ to work well across multiple benchmark datasets and use it in all reported experiments.
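Eq. 3 translates directly into code. The temperatures below are illustrative only; the value used in the reported experiments is not reproduced here.

```python
import numpy as np

def sharpen(q_bar, T):
    # Eq. 3: raise each class probability to 1/T, then renormalize each row.
    p = q_bar ** (1.0 / T)
    return p / p.sum(axis=1, keepdims=True)

q_bar = np.array([[0.6, 0.3, 0.1]])
unchanged = sharpen(q_bar, 1.0)    # T = 1 leaves the distribution as is
sharper = sharpen(q_bar, 0.5)      # squares and renormalizes: ≈ [0.783, 0.196, 0.022]
near_one_hot = sharpen(q_bar, 0.05)
```

As $T$ shrinks, the largest class dominates after exponentiation, so the output approaches a one-hot vector on the most likely class.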

3.3 MixUp

MixUp was proposed by Zhang et al. [18] as a regularization technique to encourage high-margin decision boundaries and was utilized in SSL by ICT [15] and MixMatch [2]. Given two samples $x_1, x_2$ with corresponding label distributions (or targets) $p_1, p_2$, and Beta distribution parameter $\alpha$, our MixUp function generates a new sample $(x', p')$ as follows:

$$\lambda \sim \operatorname{Beta}(\alpha, \alpha) \qquad (4)$$
$$\lambda' = \max(\lambda, 1 - \lambda) \qquad (5)$$
$$x' = \lambda' x_1 + (1 - \lambda') x_2 \qquad (6)$$
$$p' = \lambda' p_1 + (1 - \lambda') p_2 \qquad (7)$$

Following data augmentation and the generation of unlabeled targets, we apply MixUp separately to both the labeled samples and the unlabeled samples (see lines 17-18 of algorithm 1). As in MixMatch, the resulting mixed labeled and unlabeled samples are linear interpolations of samples from both the labeled and unlabeled collections, but they are weighted to more closely resemble their "original" distribution (eq. 6). In other words, the mixed labeled samples are more similar to the original labeled points, and the mixed unlabeled samples are more similar to the original unlabeled points.
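The batch-level procedure can be sketched MixMatch-style: every labeled and unlabeled sample is mixed with a random partner drawn from the shuffled union of both batches, using $\lambda' = \max(\lambda, 1 - \lambda)$ so each result stays closer to its original. Sampling one $\lambda$ per example and the function name are assumptions of this sketch.

```python
import numpy as np

def mixup_batches(xl, pl, xu, pu, alpha, rng):
    # Pool labeled and unlabeled samples, then pick a shuffled partner for each one.
    x_all = np.concatenate([xl, xu])
    p_all = np.concatenate([pl, pu])
    perm = rng.permutation(len(x_all))
    x_w, p_w = x_all[perm], p_all[perm]
    # Eqs. 4-5: lambda ~ Beta(alpha, alpha), lambda' = max(lambda, 1 - lambda) >= 0.5.
    lam = rng.beta(alpha, alpha, size=len(x_all))
    lam = np.maximum(lam, 1.0 - lam)
    lam_x = lam.reshape(-1, *([1] * (x_all.ndim - 1)))  # broadcast over input dims
    lam_p = lam[:, None]
    # Eqs. 6-7: interpolate inputs and targets with the same coefficient.
    x_mix = lam_x * x_all + (1.0 - lam_x) * x_w
    p_mix = lam_p * p_all + (1.0 - lam_p) * p_w
    n = len(xl)
    return (x_mix[:n], p_mix[:n]), (x_mix[n:], p_mix[n:]), lam

rng = np.random.default_rng(0)
xl, pl = rng.normal(size=(4, 3)), np.eye(2)[[0, 1, 0, 1]]  # toy labeled batch, one-hot labels
xu, pu = rng.normal(size=(6, 3)), np.full((6, 2), 0.5)      # toy unlabeled batch, soft targets
(xl_mix, pl_mix), (xu_mix, pu_mix), lam = mixup_batches(xl, pl, xu, pu, 0.75, rng)
```

Because $\lambda' \ge 0.5$, the first `len(xl)` mixed samples remain dominated by labeled inputs and the rest by unlabeled inputs, so they can still be fed to the supervised and unsupervised loss terms respectively.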

Method 0% 25% 50% 75% 100%
MT [14]
MM [2]
RealMix
Table 1: Results comparing the error of RealMix to other SSL methods (MT: Mean Teacher, MM: MixMatch) on the distribution mismatch experiment. 0% mismatch serves as the baseline in which the labeled and unlabeled data are drawn from the same distribution. While the other methods steadily increase in error as the amount of mismatch increases, RealMix is surprisingly able to surpass baseline performance when there is over 75% mismatch.

3.4 Out-of-Distribution Masking

To combat the effects of labeled and unlabeled samples coming from different distributions on current SSL methods (see results of Mean Teacher and MixMatch in table 1), we introduce out-of-distribution masking. The goal of out-of-distribution masking is to mask out the unlabeled samples that the classifier has the least confidence in, computing gradients only on samples that have a confidence above a moving threshold and are thus likely in-distribution samples (see fig. 4).

It is important that the masking threshold is not static: over the course of training, we found that entropy minimization tended to push the confidence values of most unlabeled samples above a specified static threshold, rendering the threshold useless. To find a dynamic threshold for each training step, we instead specify a hyperparameter that dictates what percentage of unlabeled samples to mask. We then exclude from training the samples whose confidence values fall in that bottom percentage of the batch. Intuitively, this hyperparameter can be thought of as the level of "noise" present in the unlabeled dataset.

Out-of-distribution masking makes RealMix highly effective at mitigating unlabeled data mismatch: RealMix maintains performance above a supervised baseline regardless of the amount of induced mismatch (see table 1 and fig. 2). We also perform an ablation on the masking hyperparameter in table 6 to show that out-of-distribution masking boosts performance even if its optimal value is not found.
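A minimal sketch of the dynamic threshold follows, assuming confidence is the highest predicted class probability and the threshold is recomputed per batch as a rank cut. The parameter name `mask_fraction` (the masked percentage expressed as a fraction) and the function name are hypothetical.

```python
import numpy as np

def ood_mask(probs, mask_fraction):
    # Confidence of each unlabeled sample = its highest predicted class probability.
    confidence = probs.max(axis=1)
    n = len(confidence)
    n_drop = int(np.floor(mask_fraction * n))
    order = np.argsort(confidence)            # least confident first
    keep = np.ones(n, dtype=bool)
    keep[order[:n_drop]] = False              # drop the bottom mask_fraction of this batch
    # The dynamic threshold is the lowest confidence that survives the cut.
    threshold = confidence[order[n_drop]] if n_drop < n else np.inf
    return keep, threshold

probs = np.array([[0.9, 0.1], [0.4, 0.6], [0.7, 0.3], [0.5, 0.5]])  # confidences 0.9, 0.6, 0.7, 0.5
keep, threshold = ood_mask(probs, 0.5)
```

Because the cut is a fixed fraction of each batch rather than a fixed confidence value, it keeps excluding the least confident samples even after entropy minimization has pushed all confidences upward.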

Figure 2: Error rate comparison of RealMix to other state-of-the-art methods on the distribution mismatch experiment. All experiments use 6 animal classes from CIFAR10 with 400 samples per class as labeled data, while varying the overlap of animal classes that make up the unlabeled data. For example, at 0% mismatch the unlabeled distribution is made up of 4 animal classes, and at 100% mismatch it is made up of 4 non-animal classes. We present supervised baseline results using the 2400 labeled samples, which achieve an error rate of 20.32%.

3.5 Training Signal Annealing

Semi-supervised learning algorithms have been evaluated on labeled set sizes as small as 250 labels, while the unlabeled data collections are often orders of magnitude larger [2, 11, 15, 16]. To mitigate overfitting to such small quantities of labeled data, Xie et al. [16] introduce training signal annealing (TSA). TSA delays the release of the training signal according to a schedule (logistic, linear, or exponential), limiting training on labeled samples that the classifier is already confident about. We find that TSA with a linear schedule helps when training on 250 or fewer labeled samples.
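The sketch below follows the TSA formulation from UDA [16] as we understand it: a threshold $\eta_t = \alpha_t (1 - 1/K) + 1/K$ rises from $1/K$ to 1 over training, and a labeled example contributes to the supervised loss only while its predicted probability for the true class is at or below $\eta_t$. The schedule constants (the factor 5) and function names are stated as assumptions, not taken from RealMix's code.

```python
import numpy as np

def tsa_threshold(step, total_steps, num_classes, schedule="linear"):
    # eta_t = alpha_t * (1 - 1/K) + 1/K, rising from 1/K at the start to 1 at the end.
    t = step / total_steps
    if schedule == "linear":
        alpha_t = t
    elif schedule == "log":
        alpha_t = 1.0 - np.exp(-t * 5.0)
    elif schedule == "exp":
        alpha_t = np.exp((t - 1.0) * 5.0)
    else:
        raise ValueError(f"unknown schedule: {schedule}")
    return alpha_t * (1.0 - 1.0 / num_classes) + 1.0 / num_classes

def tsa_mask(probs, labels, step, total_steps, schedule="linear"):
    # Keep labeled examples whose true-class probability is at or below eta_t.
    eta = tsa_threshold(step, total_steps, probs.shape[1], schedule)
    p_true = probs[np.arange(len(labels)), labels]
    return p_true <= eta

probs = np.array([[0.9, 0.1], [0.3, 0.7]])
labels = np.array([0, 1])  # true-class probabilities: 0.9 and 0.7
```

Early in training the threshold is low, so confidently predicted labeled examples are excluded, which limits overfitting to a tiny labeled set; by the end every example is used.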

4 Experiments

In the following sections, we show RealMix’s performance on benchmark datasets, a distribution mismatch experiment, comparison to transfer learning, and an ablation study on its components.

To allow for comparison with prior SSL techniques, we follow the WRN-28-2 architecture [17], hyperparameter selection, and evaluation procedure described by Berthelot et al. [2] (which uses weight decay and an exponential moving average of model parameters). A key difference is that we train for only 500k iterations and use only 1 GPU, similarly to Oliver et al. [11], to emulate a more realistic training setup. We report uncertainty values as the standard deviation across 2 random seeds where possible. We also base our implementations of RealMix and the other SSL methods presented in this paper on those created by Berthelot et al. [2] in order to provide the research community with reproducible results.
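The exponential moving average of model parameters used at evaluation time can be sketched as below. Representing parameters as a dict of arrays is a hypothetical choice, and the default decay of 0.999 is an assumption rather than a value stated in this paper.

```python
import numpy as np

def ema_update(ema_params, params, decay=0.999):
    # Parameter-wise exponential moving average: ema <- decay * ema + (1 - decay) * params.
    return {name: decay * ema_params[name] + (1.0 - decay) * value
            for name, value in params.items()}

params = {"w": np.ones(3), "b": np.zeros(1)}
ema = {name: np.zeros_like(v) for name, v in params.items()}
for _ in range(10):
    ema = ema_update(ema, params, decay=0.9)  # small decay so the toy example moves quickly
```

The averaged copy changes slowly and smooths out step-to-step noise in the trained weights; it is used only for evaluation, while gradients update the raw parameters.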

4.1 Baselines

We report baseline results for the CIFAR10 and SVHN experiments for the Π-Model [8], VAT [9], and Mean Teacher [14] from those presented in [2], and re-run MixMatch [2] and RealMix according to the settings described in the previous section. For the distribution mismatch experiment (illustrated in fig. 2), we additionally re-run and report results for Mean Teacher.

4.2 Results

4.2.1 CIFAR10 and SVHN

We compare RealMix and prior SSL methods on the benchmark datasets CIFAR10 and SVHN, with results shown in table 2, fig. 3, and table 3. The typical evaluation method for SSL methods is to discard all but a number of labels and report performance across varying labeled set sizes. For CIFAR10, we evaluate RealMix and MixMatch [2] on 4 labeled set sizes (250, 500, 1000, 4000) and present the results found by Berthelot et al. [2] for the Π-Model, VAT, and Mean Teacher. For SVHN, we evaluate RealMix and MixMatch on 2 labeled set sizes (250, 4000) and compare them with the results found by Berthelot et al. [2] for the Π-Model, VAT, and Mean Teacher. Note that these 3 models are run for 500k iterations more than the RealMix and MixMatch experiments, leaving room for further improvement on RealMix given a larger training budget.

We find that RealMix sets a new state-of-the-art on CIFAR10 with 250 labels, with an error rate of 9.79%, a 17% reduction in error relative to the previous state-of-the-art, MixMatch. Compared to the fully supervised baseline with an error rate of 4.48%, RealMix uses 200x fewer labels to capture over 94% of the test accuracy. We also find that RealMix is competitive with MixMatch on SVHN across labeled set sizes.
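The headline figures can be checked with simple arithmetic. This assumes the MixMatch comparison point is the re-run 11.78% error on CIFAR10 with 250 labels reported in table 4, and uses the 4.48% fully supervised error trained on all 50000 labels.

```python
# Error rates (%) on CIFAR10, taken from this paper's text and table 4.
realmix_err = 9.79     # RealMix, 250 labels
mixmatch_err = 11.78   # MixMatch (re-run), 250 labels
supervised_err = 4.48  # fully supervised, all 50000 labels

relative_reduction = (mixmatch_err - realmix_err) / mixmatch_err   # ≈ 0.169, i.e. ~17%
accuracy_recovered = (100 - realmix_err) / (100 - supervised_err)  # ≈ 0.944, i.e. over 94%
label_ratio = 50000 / 250                                          # 200.0, i.e. 200x fewer labels
```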

Method 250 Labels 4000 Labels
Π-Model [8]
VAT  [9]
Mean Teacher [14]
MixMatch [2]
RealMix
Table 2: Results comparing error of RealMix to other SSL methods on CIFAR10 with 250 and 4000 labeled samples. The supervised baseline trained on all 50000 CIFAR10 samples achieves error of 4.48%.
Figure 3: Results of SSL algorithms on CIFAR10 across varying labeled set sizes. Note that the results for the Π-Model, VAT, and Mean Teacher come from Berthelot et al. [2] and are run for 500k iterations more than the RealMix and MixMatch experiments. RealMix achieves state-of-the-art performance on CIFAR10 with 250 labels with an error rate of 9.79%, while the supervised baseline trained on all 50000 CIFAR10 samples achieves an error of 4.48%.
Method 250 Labels 4000 Labels
Π-Model [8]
VAT  [9]
Mean Teacher [14]
MixMatch [2]
RealMix
Table 3: Results comparing error of RealMix to other SSL methods on SVHN with 250 and 4000 labeled samples. The supervised baseline trained on all 73257 SVHN samples achieves error of 2.72%.

4.2.2 Distribution Mismatch

Oliver et al. [11] introduced a distribution mismatch experiment using CIFAR10 to evaluate the robustness of SSL methods to out-of-distribution examples in unlabeled data. By evaluating robustness to mismatch, a practitioner can determine in which situations SSL may be preferable to using labeled samples alone.

CIFAR10 contains two sets of classes: animals (bird, cat, deer, dog, frog, horse) and transportation (airplane, automobile, ship, truck). We simulate a mismatch by making the labeled distribution consist of the 6 animal classes each with 400 labels and varying the overlap of animal classes that make up the unlabeled distribution. For example, at 0% mismatch the unlabeled distribution consists of 4 classes that are all animals and at 100% mismatch, the unlabeled distribution consists of the 4 transportation classes. We evaluate RealMix, MixMatch, and Mean Teacher on varying levels of mismatch (0%, 25%, 50%, 75%, 100%) and present our results in fig. 2.
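The class composition of each mismatch level can be written down explicitly. Which animal classes remain in the unlabeled set at each level is not specified in the text, so taking the first ones in list order is an assumption; the function name is hypothetical.

```python
ANIMALS = ["bird", "cat", "deer", "dog", "frog", "horse"]
TRANSPORT = ["airplane", "automobile", "ship", "truck"]

# Labeled set: all 6 animal classes, 400 labels each (2400 labeled samples).
labeled_counts = {c: 400 for c in ANIMALS}

def unlabeled_classes(mismatch, n_classes=4):
    # At a given mismatch level, round(mismatch * 4) of the 4 unlabeled classes are
    # transportation classes; the rest are animal classes (choice of animals assumed).
    n_out = round(mismatch * n_classes)
    return ANIMALS[:n_classes - n_out] + TRANSPORT[:n_out]

levels = {m: unlabeled_classes(m) for m in (0.0, 0.25, 0.5, 0.75, 1.0)}
```

At 0% mismatch every unlabeled class also appears in the labeled set; at 100% none does, which is the setting where only RealMix still surpasses the supervised baseline.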

Surprisingly, RealMix is able to surpass baseline performance on the 6 animal classes alone at all levels of mismatch. Our ablation study (results in table 6) shows that RealMix is robust to unlabeled distribution mismatch as a result of out-of-distribution masking. Both MixMatch and Mean Teacher are able to surpass baseline performance with limited mismatch, but perform far worse when more significant amounts of mismatch (75% and 100%) are introduced.

Notably, RealMix surpasses baseline performance even when the unlabeled classes share no overlap with the labeled classes. This suggests that the classifier is able to learn from out-of-distribution unlabeled data, which we hypothesize results from MixUp [18] generating new samples that are still "slightly" in-distribution. For each level of mismatch, we set the out-of-distribution masking hyperparameter to represent the expected percentage of mismatched unlabeled data. These values were not tuned, so RealMix's performance on this experiment could presumably improve further. We also hope that future work in SSL considers out-of-distribution robustness as a key evaluation, as in real-world settings unlabeled and labeled data do not always arise from the same distribution.

Figure 4: Illustration of the out-of-distribution masking process. RealMix produces a confidence value for each image, along with a threshold computed from the masking hyperparameter and the confidence values of the given batch. Only the images with a confidence above this dynamic threshold contribute to the unsupervised loss.

4.2.3 Transfer Learning

Transfer learning is often an attractive first option when faced with limited quantities of labeled data; we study it following the finding of Oliver et al. [11] that transfer learning may be a preferable alternative to SSL. We pre-trained a classifier on ILSVRC-2012 [13] downsampled to 32x32 and then fine-tuned it on CIFAR10 with 250 and 4000 labels.

We find that RealMix (as well as MixMatch) outperforms transfer learning and fine-tuning on labeled data alone, even though CIFAR10 and ILSVRC-2012 share overlapping classes. This suggests that the error rates of and are upper bounds on the performance using transfer learning and fine-tuning. We also find that transfer learning can be complementary to SSL. Specifically, we set a new state-of-the-art on CIFAR10 with 250 labels and reduce the error rate to just 8.48%. We also attempted transfer learning on SVHN and found that SSL methods performed far better than transfer learning, likely because the two datasets are quite different.

Method 250 Labels 4000 Labels
TL & Fine-Tuning 20.60 8.45
MixMatch[2] 11.78 6.45
RealMix 9.78 6.39
RealMix + TL
Table 4: Results comparing error of RealMix to transfer learning (from ILSVRC-2012) on CIFAR10 with 250 and 4000 labeled samples. We find that not only are recent SSL methods and RealMix able to surpass transfer learning alone, but combining transfer learning with RealMix can further improve results.

4.2.4 Ablation

We finally perform an ablation study on two components of the RealMix algorithm: the unlabeled data augmentation and out-of-distribution masking (described in sections 3.1 and 3.4, respectively, and in algorithm 1).

RealMix extends the unlabeled samples with 50 copies augmented using CutOut [5], which gives our state-of-the-art error rate of 9.79% on CIFAR10 with 250 labels. Using a simpler augmentation (random translation and horizontal flips) or fewer augmented copies both give slightly weaker results (as listed in table 5), suggesting that performing targeted augmentations and making more augmented copies of unlabeled data may further improve results.

In section 4.2.2 we study the effects of distribution mismatch on RealMix and attribute its robustness to our use of out-of-distribution masking. Indeed, table 6 shows that RealMix's ability to meet or decrease the baseline error rate of 20.32% is linked to out-of-distribution masking; without it, error increases markedly.

Method CIFAR10 on 250 Labels
RealMix 9.79
RealMix w/ Simple Aug 10.42
RealMix w/ 25 Augs 10.80
Table 5: Results from ablation experiments on the augmentation type and amount from CIFAR10 on 250 labels. RealMix uses CutOut[5] to generate 50 copies of unlabeled data.
Method OOD w/ 75% Mismatch
RealMix () 16.99
RealMix () 20.73
RealMix w/o OODMask 22.70
Table 6: Results from ablation experiments on out-of-distribution masking (OODMask) on the experiment from table 1 with 75% mismatch. Using OODMask, RealMix meets or surpasses the supervised baseline performance (error of 20.32%) at multiple values of the masking hyperparameter.

5 Conclusion

In this work we presented RealMix, a novel semi-supervised learning technique that improves classification performance even when there is a significant shift between the distributions of the unlabeled and the labeled data. RealMix is, to the best of our knowledge, the only SSL approach that is able to maintain baseline performance when there is a complete mismatch between the labeled and unlabeled distributions. This is a particularly important contribution when considering the applicability of semi-supervised learning outside of academic settings, where data is scarce and often noisy.

We demonstrated that RealMix achieves state-of-the-art performance on common semi-supervised learning benchmarks such as CIFAR10 and SVHN, notably achieving an error rate of 9.79% on CIFAR10 using 250 labels.

Additionally, we showed that transfer learning complements our method, further reducing the error on CIFAR10 with 250 labels to just 8.48%.

We hope that these results illustrate the practicality of semi-supervised learning in real world settings, and alongside the provided source code, will foster future research to further advance semi-supervised learning techniques.

References