Deep learning has been hugely successful in areas such as image classification (krizhevsky2012imagenet; he2016deep; zagoruyko2016wide; huang2017densely) and speech recognition (sak2014long; sercu2016very), where a large amount of labeled data is available. However, in practice it is often prohibitively expensive to create a large, high-quality labeled dataset, due to lack of time, resources, or other factors. For example, the ImageNet dataset, which consists of 3.2 million labeled images in 5,247 categories, took nearly two and a half years to complete with the aid of Amazon's Mechanical Turk (deng2009imagenet). Some medical tasks may require months of preparation, expensive hardware, and the collaboration of many experts, and are often limited by the number of participants (miotto2016deep). As a result, it is desirable to exploit unlabeled data to aid the training of deep learning models.
This form of learning is called semi-supervised learning (SSL) (chapelle2006semi). Unlike supervised learning, SSL aims to leverage unlabeled data, in conjunction with labeled data, to improve performance. SSL is typically evaluated on labeled datasets where a certain proportion of the labels have been discarded. There have been a number of instances in which SSL is reported to achieve performance close to that of purely supervised learning (laine2017temporal; miyato2017virtual; tarvainen2017mean; berthelot2019mixmatch), where the purely supervised model is trained on the much larger, fully labeled dataset. However, despite significant progress in this field, it remains difficult to quantify when unlabeled data may aid performance, except in a handful of cases (balcan2005pac; ben2008does; kaariainen2005generalization; niyogi2013manifold; rigollet2007generalization; singh2009unlabeled; wasserman2008statistical).
In this work, we restrict our attention to SSL algorithms which add a loss term to the neural network loss.
In addition to achieving state-of-the-art performance, these algorithms are the most flexible and practical, given the difficulty of hyperparameter tuning over the entire model training process.
We introduce Negative Sampling in Semi-Supervised Learning, a simple, fast, and easy-to-tune SSL algorithm motivated by negative sampling/contrastive estimation (mikolov2013dist; smith2005contrastive). In negative sampling/contrastive estimation, to train a model on unlabeled data, we exploit implicit negative evidence originating from the unlabeled samples: we seek models that discriminate a supervised example from its neighborhood, which comprises unsupervised examples, each assigned a random (and potentially wrong) class. Stated differently, the learner learns not only that the supervised example is good, but also that it is locally optimal in the space of examples, and that alternative examples are inferior. With negative sampling/contrastive estimation, instead of explaining and exploiting all of the data (which is not available during training), the model only needs to explain why the observed, supervised example is better than its unsupervised neighbors.
Overall, our method adds a loss term to the learning objective, and we show that simply adding this term to other state-of-the-art SSL objectives improves their performance. Since modern datasets often have a large number of classes (imagenet), we are motivated by the observation that it is often much easier to label a sample with a class (or classes) it does not belong to than with the one class it does belong to, exploiting ideas from negative sampling/contrastive estimation (mikolov2013dist; smith2005contrastive).
Our findings can be summarized as follows:
We propose a new SSL algorithm that is easy to tune and improves the performance of other state-of-the-art SSL algorithms simply by adding its loss to their objectives.
Adding our loss to the state-of-the-art non-Mixup (zhang2017mixup) loss for unlabeled data, Virtual Adversarial Training (VAT) (miyato2017virtual), we observe superior performance compared to state-of-the-art alternatives such as Pseudo-Label (lee2013pseudo), plain VAT (miyato2017virtual), and VAT with Entropy Minimization (miyato2017virtual; oliver2018realistic), on the standard SSL benchmarks SVHN and CIFAR10.
Adding our loss to the state-of-the-art Mixup-based SSL method, MixMatch (berthelot2019mixmatch), produces superior performance on the standard SSL benchmarks SVHN, CIFAR10, and STL-10.
In short, adding our loss to existing SSL algorithms is an easy way to improve performance and requires limited extra computation for hyperparameter tuning, since the loss is interpretable, fast, and easy to tune.
2 Negative Sampling in Semi-Supervised Learning
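Let the set of labeled samples be denoted as $\{(x_i, y_i)\}_{i=1}^{n}$, with $x_i$ being the input and $y_i \in \{1, \dots, K\}$ being the associated label, and let the set of unlabeled samples be denoted as $\{x^u_i\}_{i=1}^{m}$, each with unknown correct label $y^u_i$. For the rest of the text, we consider the cross-entropy loss, one of the most widely used loss functions for classification. The objective function for the cross-entropy loss over the labeled examples is:

$$\mathcal{L}_l(\theta) = -\sum_{i=1}^{n} \sum_{k=1}^{K} \mathbb{1}(y_i = k) \log p_{ik},$$

where there are $n$ labeled samples and $K$ classes, $\mathbb{1}(y_i = k)$ is the indicator function that equals 1 when $y_i = k$ and 0 otherwise, and $p_{ik}$ is the output of the classifier (a probability) for sample $x_i$ for class $k$.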
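For simplicity, we perform the following relabeling of the $m$ unlabeled samples: $x_{n+i} := x^u_i$ and $y_{n+i} := y^u_i$ for all $i \in \{1, \dots, m\}$. In the hypothetical scenario where the labels for the unlabeled data are known, and with $\theta$ denoting the parameters of the model, the likelihood would be:

$$P(y_1, \dots, y_{n+m} \mid x_1, \dots, x_{n+m}; \theta) = \prod_{i=1}^{n} p_{i y_i} \cdot \prod_{i=n+1}^{n+m} \Big(1 - \sum_{k \neq y_i} p_{ik}\Big),$$

where the second product follows from the definition of the quantities $p_{ik}$, which represent a probability distribution over the $K$ classes and, consequently, sum to one.

Taking the negative logarithm splits the loss function into two components, the supervised part and the unsupervised part. The negative log-likelihood can now be written as:

$$\mathcal{L}(\theta) = -\sum_{i=1}^{n} \log p_{i y_i} - \sum_{i=n+1}^{n+m} \log\Big(1 - \sum_{k \neq y_i} p_{ik}\Big).$$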
Computing the unsupervised part exactly would require the true labels. Instead, we draw on ideas from negative sampling/contrastive estimation (mikolov2013dist; smith2005contrastive): for each unlabeled example $x_i$ in the unsupervised part, we assign a set $\mathcal{N}_i$ of labels drawn from the $K$ classes. These labels indicate classes that the sample does not belong to. As the number of classes increases, the probability that $\mathcal{N}_i$ includes the correct label becomes small; for instance, if $S$ labels are drawn uniformly at random without replacement, this probability is $S/K$. The labels could be selected uniformly at random, by Nearest Neighbor search, or even based on the output probabilities of the network, with the hope that the correct label is not picked. Our idea is analogous to the word2vec setting (mikolov2013dist), which is described in Appendix A.
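To make the claim that the correct label is rarely picked concrete, the following minimal sketch assumes negative labels are drawn uniformly at random without replacement (the function names are hypothetical). It samples negative labels and computes the exact probability that a fixed true class lands in the sampled set, which reduces to $S/K$.

```python
import random
from fractions import Fraction
from math import comb


def sample_uniform_negatives(num_classes, num_negatives, rng):
    """Draw `num_negatives` distinct labels uniformly from 0..num_classes-1.

    Hypothetical helper: the labels are only *hoped* to be incorrect,
    since the true label of an unlabeled sample is unknown.
    """
    return set(rng.sample(range(num_classes), num_negatives))


def prob_true_label_included(num_classes, num_negatives):
    """Exact probability that a uniform S-subset of K classes contains one
    fixed (true) class: C(K-1, S-1) / C(K, S), which simplifies to S / K."""
    return Fraction(comb(num_classes - 1, num_negatives - 1),
                    comb(num_classes, num_negatives))


if __name__ == "__main__":
    rng = random.Random(0)
    K, S, true_label = 100, 3, 7
    trials = 20000
    hits = sum(true_label in sample_uniform_negatives(K, S, rng)
               for _ in range(trials))
    print("empirical:", hits / trials)
    print("exact:", float(prob_true_label_included(K, S)))  # 0.03
```

With 100 classes and 3 negative labels per sample, roughly 3% of unlabeled samples would receive their own class as a "negative", which is the noise the method tolerates.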
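The approach above assumes the use of the full dataset for both the supervised and unsupervised parts. In practice, we usually train models with stochastic gradient descent, so we implement a mini-batch variant of this approach with batch sizes $B_l$ and $B_u$ for labeled and unlabeled data, respectively. In particular, for a labeled mini-batch of size $B_l$ (with $p_{ik}$ the predicted probability of class $k$ for sample $x_i$), the supervised term is approximated, up to a constant scaling, as:

$$\mathcal{L}_l(\theta) \approx -\frac{1}{B_l} \sum_{i=1}^{B_l} \log p_{i y_i}.$$

The unsupervised part, computed on an unlabeled mini-batch of size $B_u$ in which each unlabeled sample $x_i$ is connected with a set $\mathcal{N}_i$ of (hopefully incorrect) labels, is approximated as:

$$\mathcal{L}_u(\theta) \approx -\frac{1}{B_u} \sum_{i=1}^{B_u} \log\Big(1 - \sum_{k \in \mathcal{N}_i} p_{ik}\Big).$$

Based on the above, our loss is:

$$\mathcal{L}(\theta) = -\frac{1}{B_l} \sum_{i=1}^{B_l} \log p_{i y_i} - \frac{1}{B_u} \sum_{i=1}^{B_u} \log\Big(1 - \sum_{k \in \mathcal{N}_i} p_{ik}\Big).$$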
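A plain-Python sketch of the mini-batch objective may help, assuming it takes the form $-\frac{1}{B_l}\sum_i \log p_{i y_i} - \frac{1}{B_u}\sum_i \log(1 - \sum_{k \in \mathcal{N}_i} p_{ik})$ and that probabilities are given as lists. The function names and the `eps` guard against `log(0)` are illustrative assumptions; in training, the probabilities would be network outputs inside an autodiff framework.

```python
import math


def supervised_loss(probs, labels):
    """-(1/B_l) * sum_i log p_{i, y_i} over a labeled mini-batch."""
    return -sum(math.log(p[y]) for p, y in zip(probs, labels)) / len(probs)


def negative_sampling_loss(probs, negatives, eps=1e-12):
    """-(1/B_u) * sum_i log(1 - sum_{k in N_i} p_{ik}) over an unlabeled batch.

    `negatives[i]` is the set N_i of (hopefully incorrect) labels of sample i.
    `eps` (an assumption) avoids log(0) if all mass sits on negative labels.
    """
    total = 0.0
    for p, neg in zip(probs, negatives):
        mass_on_negatives = sum(p[k] for k in neg)
        total += -math.log(max(1.0 - mass_on_negatives, eps))
    return total / len(probs)


def total_loss(lab_probs, labels, unl_probs, negatives):
    """Unweighted sum of the supervised and negative sampling terms."""
    return (supervised_loss(lab_probs, labels)
            + negative_sampling_loss(unl_probs, negatives))


if __name__ == "__main__":
    lab = [[0.7, 0.2, 0.1]]
    unl = [[0.1, 0.6, 0.3]]
    print(total_loss(lab, [0], unl, [{0}]))  # -log(0.7) - log(0.9)
```

A useful sanity check: if $\mathcal{N}_i$ contains every class except the true one, the unsupervised term reduces to $-\log p_{i y_i}$, i.e., ordinary cross-entropy, matching the derivation from the full likelihood.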
Thus, our loss is simply an additive term that can easily be included in many existing SSL algorithms, as we show next. For clarity, Algorithm 1 gives a pseudocode implementation in which negative labels are identified by the label probability falling below a threshold, where the probability is the output of the classifier or comes from another source.
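The threshold rule for choosing negative labels can be sketched as follows. This is an illustration under our own assumptions, not Algorithm 1 itself: the probabilities come from a softmax over hypothetical logits, and the helper names are made up for this example.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]


def threshold_negatives(logits_batch, threshold):
    """For each unlabeled sample, return the classes whose predicted
    probability falls below `threshold`; these become its negative labels.

    In training, the probabilities would come from the current network and
    be treated as constants when selecting labels.
    """
    negatives = []
    for logits in logits_batch:
        probs = softmax(logits)
        negatives.append({k for k, p_k in enumerate(probs) if p_k < threshold})
    return negatives


if __name__ == "__main__":
    batch = [[4.0, 1.0, 0.0, -2.0], [0.0, 0.0, 0.0, 0.0]]
    print(threshold_negatives(batch, 0.05))  # [{1, 2, 3}, set()]
```

A confident prediction yields many negative labels, while a uniform prediction yields none, so uncertain samples contribute nothing to the loss.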
3 Related Work
In this paper, we restrict our attention to a subset of SSL algorithms that add a loss to the supervised loss function. These algorithms tend to be more practical in terms of hyperparameter tuning (berthelot2019mixmatch). There are a number of SSL algorithms not discussed in this paper, including "transductive" models (joachim1999trans; joachim2003trans; gammerman1998learn), graph-based methods (zhu2003semi; bengio2006label), and generative modeling (joachim2003trans; belkin2002laplacian; salak2007deepbelief; coates2011encoding; goodfellow2011spike; kingma2014semi; odena2016semi; pu2016variational; salimans2016improved). For a comprehensive overview of SSL methods, refer to chapelle2006semi or zhu2003semi. Below, we describe the categories of SSL relevant to this paper.
3.1 Consistency Regularization
Consistency regularization applies data augmentation to semi-supervised learning, based on the following intuition: small perturbations of each sample should not significantly change the output of the network. This is usually achieved by minimizing some distance measure between the outputs of the network with and without perturbations of the input. The most straightforward distance measure is the mean squared error, used by the Π model (laine2017temporal; sajjadi2016regularization). The Π model adds the distance term $\|f(x;\theta) - f(\hat{x};\theta)\|_2^2$, where $f(x;\theta)$ is the network's output distribution and $\hat{x}$ is the result of a stochastic perturbation of $x$, to the supervised classification loss as a regularizer, with some weight.
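A minimal sketch of the Π-model regularizer, assuming the two output distributions have already been computed by two forward passes (one on the perturbed input) and that the consistency term is averaged over the batch; the function names and the weight are hypothetical.

```python
def squared_distance(p, q):
    """Squared L2 distance between two output distributions."""
    return sum((a - b) ** 2 for a, b in zip(p, q))


def pi_model_loss(sup_loss, clean_out, perturbed_out, weight):
    """Supervised loss plus a weighted, batch-averaged consistency term.

    clean_out[i] and perturbed_out[i] are the network's output distributions
    for sample i without and with a stochastic perturbation.
    """
    consistency = sum(
        squared_distance(p, q) for p, q in zip(clean_out, perturbed_out)
    ) / len(clean_out)
    return sup_loss + weight * consistency


if __name__ == "__main__":
    print(pi_model_loss(1.0, [[0.5, 0.5]], [[0.7, 0.3]], weight=10.0))  # about 1.8
```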
Mean Teacher (tarvainen2017mean) observes that the target prediction can be unstable over the course of training with the Π model approach, and proposes a prediction function parameterized by an exponential moving average of the model parameter values. Mean Teacher adds the distance term $\|f(x;\theta) - f(\hat{x};\theta')\|_2^2$, where $\theta'$ is an exponential moving average of $\theta$, to the supervised classification loss with some weight. However, these methods rely on domain-specific data augmentation.
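The exponential moving average behind Mean Teacher can be sketched in a few lines, treating parameters as a flat list of floats. The decay value of 0.99 is an illustrative assumption, and `ema_update` is a hypothetical name.

```python
def ema_update(teacher, student, decay=0.99):
    """One EMA step, elementwise: theta' <- decay * theta' + (1 - decay) * theta."""
    return [decay * t + (1.0 - decay) * s for t, s in zip(teacher, student)]


if __name__ == "__main__":
    teacher = [0.0]
    for _ in range(100):
        # A constant student parameter of 1.0: the teacher approaches it
        # as 1 - decay**steps.
        teacher = ema_update(teacher, [1.0])
    print(teacher[0])
```

Because the teacher averages many past students, its predictions change more slowly, which is what makes it a steadier consistency target.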
3.1.1 Virtual Adversarial Training
Virtual Adversarial Training (VAT) (miyato2017virtual) approximates the perturbation to the input that most significantly affects the output class distribution, inspired by adversarial examples (goodfellow2015explain; szegedy2014intriguing). VAT computes an approximation of the perturbation as:

$$r_{\text{adv}} \approx \epsilon \frac{g}{\|g\|_2}, \quad g = \nabla_r D\big(f(x;\theta), f(x + r;\theta)\big)\Big|_{r = \xi d}, \quad d \sim \mathcal{N}\big(0, \dim(x)^{-1/2} I\big),$$

where $x$ is an input data sample, $\dim(x)$ is its dimension, $D$ is a non-negative function that measures the divergence between two distributions, and $\epsilon$ and $\xi$ are scalar hyperparameters. Consistency regularization is then used to minimize the distance between the outputs of the network with and without the perturbation $r_{\text{adv}}$ applied to the input. Since we follow the work of oliver2018realistic almost exactly, we select the best-performing consistency regularization SSL method in that work, VAT, for comparison and combination with our method for non-Mixup SSL; the Mixup procedure is described later.
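To illustrate VAT's approximation, $r_{\text{adv}} \approx \epsilon\, g/\|g\|_2$ with $g$ the gradient of the divergence at $r = \xi d$ for a random direction $d$, here is a self-contained sketch on a toy linear-softmax classifier. Several assumptions apply: the classifier stands in for the network, $D$ is taken to be KL divergence, $d$ is normalized to unit length (a common implementation choice), and the gradient is estimated by central finite differences rather than automatic differentiation.

```python
import math
import random


def softmax(z):
    """Numerically stable softmax."""
    m = max(z)
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]


def predict(W, x):
    """Toy linear-softmax classifier softmax(W x), a stand-in for the network."""
    return softmax([sum(w * xj for w, xj in zip(row, x)) for row in W])


def kl(p, q):
    """KL(p || q), used here as the divergence D."""
    return sum(pk * math.log(pk / qk) for pk, qk in zip(p, q) if pk > 0)


def l2norm(v):
    return math.sqrt(sum(t * t for t in v))


def vat_perturbation(W, x, eps=1.0, xi=1e-2, h=1e-5, rng=None):
    """One power-iteration step of the VAT approximation:
    r_adv ~= eps * g / ||g||, g = grad_r D(f(x), f(x + r)) at r = xi * d.
    The gradient uses central finite differences, for illustration only.
    """
    rng = rng or random.Random(0)
    d = [rng.gauss(0.0, 1.0) for _ in x]
    norm_d = l2norm(d)
    d = [v / norm_d for v in d]  # random unit direction
    p_clean = predict(W, x)

    def divergence(r):
        return kl(p_clean, predict(W, [a + b for a, b in zip(x, r)]))

    r0 = [xi * v for v in d]
    g = []
    for j in range(len(x)):
        plus = list(r0)
        plus[j] += h
        minus = list(r0)
        minus[j] -= h
        g.append((divergence(plus) - divergence(minus)) / (2.0 * h))
    norm_g = l2norm(g)
    return [eps * v / norm_g for v in g]


if __name__ == "__main__":
    W = [[1.0, -0.5], [0.3, 0.8], [-1.2, 0.4]]
    x = [0.5, -0.2]
    r_adv = vat_perturbation(W, x, eps=0.5)
    x_adv = [a + b for a, b in zip(x, r_adv)]
    # The VAT consistency term would then be D(f(x), f(x + r_adv)).
    print(r_adv, kl(predict(W, x), predict(W, x_adv)))
```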
3.2 Entropy minimization
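The goal of entropy minimization (grandvalet2005entmin) is to discourage the decision boundary from passing near samples where the network produces low-confidence predictions. One way to achieve this is by adding a simple loss term that minimizes the entropy of the predictions on the $m$ unlabeled samples, with $K$ total classes:

$$\mathcal{L}_{\text{ent}} = -\frac{1}{m} \sum_{i=1}^{m} \sum_{k=1}^{K} f(x^u_i;\theta)_k \log f(x^u_i;\theta)_k.$$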
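A minimal sketch of the entropy-minimization term, assuming predicted class distributions are given as lists and the term is averaged over the unlabeled batch; the function names are hypothetical.

```python
import math


def entropy(p):
    """Shannon entropy (in nats) of one predicted class distribution."""
    return -sum(p_k * math.log(p_k) for p_k in p if p_k > 0)


def entropy_min_loss(unlabeled_probs):
    """Mean prediction entropy over a batch of unlabeled samples."""
    return sum(entropy(p) for p in unlabeled_probs) / len(unlabeled_probs)


if __name__ == "__main__":
    # A confident prediction has zero entropy; a uniform one has the maximum, log K.
    print(entropy_min_loss([[1.0, 0.0], [0.5, 0.5]]))  # log(2) / 2
```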
Entropy minimization on its own has not demonstrated competitive performance in SSL; however, it can be combined with VAT for stronger results (miyato2017virtual; oliver2018realistic). We include entropy minimization with VAT in our experiments.
3.3 Pseudo-Labeling
Pseudo-Labeling (lee2013pseudo) is a simple, easy-to-tune method that is widely used in practice. For a particular sample, it requires only the probability of each class, i.e., the output of the network, and labels the sample with a class if that class's probability exceeds a certain threshold. The sample is then treated as a labeled sample under the standard supervised loss function. Pseudo-Labeling is closely related to entropy minimization, but only enforces low-entropy predictions for predictions that are already low-entropy. We emphasize that the popularity of Pseudo-Labeling is likely due to its simplicity and the limited extra cost of its hyperparameter search.
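A minimal sketch of Pseudo-Labeling as described above. The threshold of 0.95 and the choice to average over the whole unlabeled batch (so unconfident samples contribute zero) are illustrative assumptions, not settings from the text.

```python
import math


def pseudo_label(probs, threshold=0.95):
    """Return the argmax class if its probability exceeds `threshold`, else None."""
    best = max(range(len(probs)), key=lambda k: probs[k])
    return best if probs[best] > threshold else None


def pseudo_label_loss(unlabeled_probs, threshold=0.95):
    """Cross-entropy against confident pseudo-labels, averaged over the batch.
    Samples without a confident class contribute nothing (an assumption)."""
    total = 0.0
    for p in unlabeled_probs:
        y = pseudo_label(p, threshold)
        if y is not None:
            total += -math.log(p[y])
    return total / len(unlabeled_probs)


if __name__ == "__main__":
    print(pseudo_label([0.97, 0.02, 0.01]), pseudo_label([0.6, 0.4]))  # 0 None
```

This makes the link to entropy minimization visible: only predictions that are already confident are pushed further toward one class.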
3.4 Mixup based SSL
Mixup (zhang2017mixup) combines pairs of samples and their one-hot labels, $(x_1, y_1)$ and $(x_2, y_2)$, with the following operations:

$$\lambda \sim \text{Beta}(\alpha, \alpha), \quad x' = \lambda x_1 + (1 - \lambda) x_2, \quad y' = \lambda y_1 + (1 - \lambda) y_2,$$

to produce a new sample $(x', y')$, where $\alpha$ is a hyperparameter. Mixup is a form of regularization that encourages the neural network to behave linearly between training examples, justified by Occam's Razor. In SSL, the labels are typically the labels predicted by a neural network, with some processing steps.
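The Mixup operation can be sketched with the standard library's Beta sampler, assuming inputs and labels are flat lists; `mixup` is a hypothetical helper name.

```python
import random


def mixup(x1, y1, x2, y2, alpha, rng):
    """Mixup of two samples and their one-hot (or soft) labels.

    lam ~ Beta(alpha, alpha); x' = lam*x1 + (1-lam)*x2; y' likewise.
    """
    lam = rng.betavariate(alpha, alpha)
    x_mix = [lam * a + (1 - lam) * b for a, b in zip(x1, x2)]
    y_mix = [lam * a + (1 - lam) * b for a, b in zip(y1, y2)]
    return x_mix, y_mix, lam


if __name__ == "__main__":
    x, y, lam = mixup([0.0, 2.0], [1, 0], [4.0, 6.0], [0, 1], 0.75, random.Random(0))
    print(lam, x, y)
```

The mixed label stays a valid distribution, so the usual cross-entropy or squared-error losses apply unchanged.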
Applying Mixup to SSL led to Interpolation Consistency Training (ICT) (verma2019ict) and MixMatch (berthelot2019mixmatch), which significantly improved upon previous SSL results on the standard benchmarks CIFAR10 and SVHN.
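ICT trains the model $f_\theta$ to output predictions similar to a mean teacher $f_{\theta'}$, where $\theta'$ is an exponential moving average of $\theta$. Namely, on unlabeled data, ICT encourages $f_\theta\big(\text{Mix}_\lambda(u_j, u_k)\big) \approx \text{Mix}_\lambda\big(f_{\theta'}(u_j), f_{\theta'}(u_k)\big)$, where $\text{Mix}_\lambda(a, b) = \lambda a + (1 - \lambda) b$.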
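MixMatch applies a number of processing steps to labeled and unlabeled data on each iteration and mixes labeled and unlabeled data together. The final loss is given by

$$\mathcal{L} = \mathcal{L}_{\mathcal{X}} + \lambda_{\mathcal{U}} \mathcal{L}_{\mathcal{U}}, \quad \mathcal{L}_{\mathcal{X}} = \frac{1}{|\mathcal{X}'|} \sum_{(x, p) \in \mathcal{X}'} H\big(p, f(x;\theta)\big), \quad \mathcal{L}_{\mathcal{U}} = \frac{1}{K |\mathcal{U}'|} \sum_{(u, q) \in \mathcal{U}'} \big\| q - f(u;\theta) \big\|_2^2,$$

where $\mathcal{X}$ is the labeled data, $\mathcal{U}$ is the unlabeled data, $\mathcal{X}'$ and $\mathcal{U}'$ are the output samples labeled by MixMatch, $H$ is the cross-entropy, $K$ is the number of classes, and $T$, $A$, $\alpha$, $\lambda_{\mathcal{U}}$ are hyperparameters. Given a batch of labeled and unlabeled samples, MixMatch applies $A$ data augmentations to each unlabeled sample $u_b$, averages the predictions across the augmentations,

$$\bar{q}_b = \frac{1}{A} \sum_{a=1}^{A} f(\hat{u}_{b,a};\theta),$$

and applies temperature sharpening,

$$\text{Sharpen}(\bar{q}_b, T)_k = \frac{\bar{q}_{b,k}^{1/T}}{\sum_{j=1}^{K} \bar{q}_{b,j}^{1/T}},$$

to the average prediction. In practice, $A$ is typically 2 and $T$ is 0.5. The unlabeled data is labeled with this sharpened average prediction.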
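MixMatch's label guessing (average the predictions over the augmentations of one unlabeled sample, then sharpen with temperature $T$) can be sketched as below. The predictions are passed in directly as lists, so the augmentation and the network are assumed to have run already; the helper names are hypothetical.

```python
def average_predictions(aug_preds):
    """Average the class distributions predicted for A augmentations of one sample."""
    A = len(aug_preds)
    return [sum(p[k] for p in aug_preds) / A for k in range(len(aug_preds[0]))]


def sharpen(p, T=0.5):
    """Temperature sharpening: p_k^(1/T) / sum_j p_j^(1/T)."""
    powered = [p_k ** (1.0 / T) for p_k in p]
    s = sum(powered)
    return [v / s for v in powered]


def guess_label(aug_preds, T=0.5):
    """Sharpened average prediction used as the unlabeled sample's label."""
    return sharpen(average_predictions(aug_preds), T)


if __name__ == "__main__":
    # A = 2 augmentations, T = 0.5: average [0.7, 0.3] -> about [0.845, 0.155].
    print(guess_label([[0.6, 0.4], [0.8, 0.2]]))
```

With $T = 0.5$, sharpening squares each probability and renormalizes, which lowers the entropy of the guessed label; $T = 1$ leaves it unchanged.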
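Let $\hat{\mathcal{U}}$ be the collection of unlabeled data labeled with these sharpened predictions. Standard data augmentation is applied to the originally labeled data; let this be denoted $\hat{\mathcal{X}}$. Let $\mathcal{W}$ denote the shuffled collection of $\hat{\mathcal{X}}$ and $\hat{\mathcal{U}}$. MixMatch alters Mixup by adding a max operation:

$$\lambda \sim \text{Beta}(\alpha, \alpha), \quad \lambda' = \max(\lambda, 1 - \lambda), \quad x' = \lambda' x_1 + (1 - \lambda') x_2, \quad p' = \lambda' p_1 + (1 - \lambda') p_2,$$

and produces $\mathcal{X}' = \big(\text{MixUp}(\hat{\mathcal{X}}_i, \mathcal{W}_i)\big)_{i=1}^{|\hat{\mathcal{X}}|}$ and $\mathcal{U}' = \big(\text{MixUp}(\hat{\mathcal{U}}_i, \mathcal{W}_{i + |\hat{\mathcal{X}}|})\big)_{i=1}^{|\hat{\mathcal{U}}|}$.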
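The max-modified Mixup and the construction of $\mathcal{X}'$ and $\mathcal{U}'$ from the shuffled collection $\mathcal{W}$ can be sketched as follows. Samples are (input, label-distribution) pairs of plain lists, and the function names are hypothetical.

```python
import random


def mixmatch_mixup(x1, p1, x2, p2, alpha, rng):
    """Mixup with lam' = max(lam, 1 - lam), so the result stays closer to (x1, p1)."""
    lam = rng.betavariate(alpha, alpha)
    lam = max(lam, 1.0 - lam)  # MixMatch's modification: lam' >= 0.5
    x = [lam * a + (1.0 - lam) * b for a, b in zip(x1, x2)]
    p = [lam * a + (1.0 - lam) * b for a, b in zip(p1, p2)]
    return x, p, lam


def mixmatch_batches(x_hat, u_hat, alpha, rng):
    """x_hat: augmented labeled (x, p) pairs; u_hat: unlabeled (u, q) pairs with
    guessed labels. Returns X' and U' built from the shuffled union W."""
    w = x_hat + u_hat
    rng.shuffle(w)
    x_prime = []
    for i, (x, p) in enumerate(x_hat):
        wx, wp = w[i]
        x_prime.append(mixmatch_mixup(x, p, wx, wp, alpha, rng)[:2])
    u_prime = []
    offset = len(x_hat)
    for i, (u, q) in enumerate(u_hat):
        wx, wp = w[offset + i]
        u_prime.append(mixmatch_mixup(u, q, wx, wp, alpha, rng)[:2])
    return x_prime, u_prime
```

The max keeps each element of $\mathcal{X}'$ dominated by a labeled sample and each element of $\mathcal{U}'$ by an unlabeled one, so the two loss terms still see mostly their own kind of data.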
Since MixMatch performs the strongest empirically, we select it as the best-performing Mixup-based SSL method for comparison and combination with our method.
4 Experiments
We separate experiments into non-Mixup-based SSL and Mixup-based SSL. For reproducibility, we follow the methodology of oliver2018realistic almost exactly for our non-Mixup-based SSL experiments, and reproduce many of its key findings. That work compares non-Mixup-based SSL methods, so we use its consistent testbed for the same purpose. MixMatch uses an almost identical setup, but with a slightly different evaluation method, and we use the official MixMatch implementation for our Mixup-based experiments.
4.1 Non-Mixup-based SSL
Following oliver2018realistic, the model employed is the standard Wide ResNet (WRN) (zagoruyko2016wide) with depth 28 and width 2, batch normalization (ioffe2015batch), and leaky ReLU activations (maas2013rectifier). The optimizer is Adam (kingma2014adam). The batch size is 100, half of which are labeled and half unlabeled. Standard procedures for regularization, data augmentation, and preprocessing are followed.
We use the standard training/validation data split for SVHN, with 65,932 training images and 7,325 validation images, and discard the labels of all but 1,000 training examples. Similarly, we use the standard training/validation data split for CIFAR10, with 45,000 training images and 5,000 validation images, and discard all but 4,000 labels. We also use the standard training/validation data split for CIFAR100, with 45,000 training images and 5,000 validation images, and discard all but 10,000 labels.
Hyperparameters are optimized to minimize validation error; test error is reported at the point of lowest validation error. We select hyperparameters that perform well for both SVHN and CIFAR10. After selecting hyperparameters on CIFAR10 and SVHN, we run almost exactly the same hyperparameters, with practically no further tuning, on CIFAR100 to determine the ability of each method to generalize to new datasets. Since VAT and VAT + EntMin use different hyperparameters for CIFAR10 and SVHN, we use those tuned for CIFAR10 on CIFAR100. For our method, both alone and combined with VAT, we divide the threshold by 10, since CIFAR100 has 10x as many classes. We run 5 seeds for all cases.
Since models are typically trained on CIFAR10 (krizhevsky2009cifar10) and SVHN (netzer2011svhn) for fewer than the 500,000 iterations (1,000 epochs) used in oliver2018realistic, we make only the following changes: we reduce the total number of iterations to 200,000, the warmup period (tarvainen2017mean) to 50,000 iterations, and the iteration of learning rate decay to 130,000. All other methodology follows that work (oliver2018realistic).
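The schedule changes above can be sketched as follows. Only the iteration counts (200,000 total, 50,000 warmup, decay at 130,000) and the 500,000 iterations = 1,000 epochs equivalence come from the text. The ramp shape (Mean Teacher's exp(-5(1 - t)^2) ramp-up), the base learning rate of 6e-4, and the decay factor of 0.2 are assumptions for illustration, and all names are hypothetical.

```python
import math

# Iteration counts stated in the text; everything else below is an assumption.
TOTAL_ITERS = 200_000
WARMUP_ITERS = 50_000
LR_DECAY_ITER = 130_000


def warmup_weight(it, warmup=WARMUP_ITERS):
    """Ramp-up factor for the unlabeled-loss weight.
    Assumed shape: the exp(-5 (1 - t)^2) ramp-up used by Mean Teacher."""
    if it >= warmup:
        return 1.0
    t = it / warmup
    return math.exp(-5.0 * (1.0 - t) ** 2)


def learning_rate(it, base_lr=6e-4, decay_factor=0.2, decay_iter=LR_DECAY_ITER):
    """Single step decay at `decay_iter`; decay_factor is a placeholder value."""
    return base_lr * decay_factor if it >= decay_iter else base_lr


def iters_to_epochs(it, iters_per_epoch=500):
    """500,000 iterations = 1,000 epochs implies 500 iterations per epoch."""
    return it / iters_per_epoch


if __name__ == "__main__":
    for it in (0, 25_000, 50_000, 129_999, 130_000, TOTAL_ITERS):
        print(it, round(warmup_weight(it), 4), learning_rate(it))
```

Under the same convention, the shortened 200,000-iteration run corresponds to 400 epochs.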
4.1.1 Baseline Methods
For baseline methods, we consider Pseudo-Labeling, due to its simplicity, which is comparable to that of our method, and VAT for its performance, in addition to VAT + Entropy Minimization. We omit the Π model and Mean Teacher, since we follow the experiments of oliver2018realistic, in which both perform worse than VAT. The supervised baseline is trained only on the labeled data that remain after the other labels have been removed. We generally follow the tuned hyperparameters in the literature and do not observe noticeable gains from further hyperparameter tuning.
4.1.2 Implementation of Negative Sampling
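We implement our method using the output probabilities of the network on the unlabeled samples, namely

$$\mathcal{N}_i = \{k \in \{1, \dots, K\} : p_{ik} < \tau\},$$

i.e., we label a sample with negative labels for the classes whose probability falls below a threshold $\tau$. The performance of our method with random negative label assignment or Nearest Neighbor-based assignment is given in Section 5. We then simply add the loss to the existing SSL loss function. Using our loss on its own gives

$$\mathcal{L} = \mathcal{L}_l + \lambda_N \mathcal{L}_u$$

for some weighting $\lambda_N$, where $\mathcal{L}_l$ is the supervised cross-entropy loss and $\mathcal{L}_u$ is our negative sampling loss. Adding it to VAT gives

$$\mathcal{L} = \mathcal{L}_l + \lambda_{\text{VAT}} \mathcal{L}_{\text{VAT}} + \lambda_N \mathcal{L}_u$$

for some weightings $\lambda_{\text{VAT}}$ and $\lambda_N$. Weighting the unsupervised terms is common practice in SSL and is also used in MixMatch. This is the simplest form of our method, and we believe there are large gains to be made with more sophisticated methods of choosing the negative labels.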
|Dataset|Supervised|Pseudo-Label|VAT|VAT + EntMin|Ours|VAT + Ours|
|---|---|---|---|---|---|---|
|CIFAR10|20.76 ± 0.28|17.56 ± 0.29|14.72 ± 0.23|14.34 ± 0.18|16.03 ± 0.05|13.94 ± 0.10|
|SVHN|12.39 ± 0.53|7.70 ± 0.22|6.20 ± 0.11|6.10 ± 0.02|6.52 ± 0.22|5.51 ± 0.14|
|CIFAR100|48.26 ± 0.25|46.91 ± 0.31|44.38 ± 0.56|43.92 ± 0.44|46.34 ± 0.37|43.70 ± 0.19|
We follow the practice in oliver2018realistic and use the same hyperparameters for our method alone and for VAT + our method on both CIFAR10 and SVHN. After selecting hyperparameters on CIFAR10 and SVHN, we run almost exactly the same hyperparameters with little further tuning on CIFAR100, where the threshold is divided by 10 since CIFAR100 has 10x as many classes.
CIFAR10. We evaluate the accuracy of each method with 4,000 labeled samples and 41,000 unlabeled samples, as is standard practice. The results are given in Table 1. For , we use a threshold , learning rate of 6e-4, and . For VAT + , we use a shared learning rate of 6e-4 and reduce from 1 to 0.3, which is identical to . All other settings remain as individually optimized.
We created 5 splits of 4,000 labeled samples, each with a different seed. Each model is trained on a different split, and test error is reported as mean and standard deviation. We find that our method alone performs reasonably well and significantly better than Pseudo-Labeling, with an improvement of over 1.5% (16.03% versus 17.56%). The best result among all algorithms is obtained by adding our loss to the VAT loss: VAT + our method achieves a 0.78% improvement over VAT and is 0.40% better than VAT + EntMin. This underscores the flexibility of our method in improving existing methods.
SVHN. We evaluate the accuracy of each method with 1,000 labeled samples and 64,932 unlabeled samples, as is standard practice. The results are shown in Table 1. We use the same hyperparameters for and VAT + as in CIFAR10.
Again, 5 splits are created, each with a different seed. Each model is trained on a different split, and test error is reported as mean and standard deviation. Here, our method alone achieves test error competitive with VAT, 6.52% versus 6.20%, and is significantly better than Pseudo-Labeling, at 7.70%. Combining our method with VAT further reduces test error by a notable margin, to 5.51%, which is 0.69% better than VAT alone and 0.59% better than VAT + EntMin.
CIFAR100. We evaluate the accuracy of each method with 10,000 labeled samples and 35,000 unlabeled samples, as is standard practice. The results are given in Table 1. For , we use a threshold , learning rate of 6e-4, and , following the settings in CIFAR10 and SVHN. For VAT + in CIFAR100, we use a shared learning rate of 3e-3 and , .
As before, we created 5 splits of 10,000 labeled samples, each with a different seed, and each model is trained on a different split. Test error is reported as mean and standard deviation. Our method alone improves test error by about 0.6% over Pseudo-Labeling, and adding it to VAT reduces test error slightly and achieves the best performance. This suggests that both EntMin and our loss boost VAT even with little hyperparameter tuning, and perhaps should be used by default. We note that the performance of SSL methods can be sensitive to hyperparameter tuning, and minor tuning may improve performance greatly. In our experiments, our method alone runs more than 2x faster than VAT.
4.2 Mixup-based SSL
We follow the methodology of berthelot2019mixmatch and continue to use the same model described in Section 4.1. In the previous section, we used the standard training/validation data splits for SVHN and CIFAR10, with all but 1,000 and all but 4,000 labels discarded, respectively. Since MixMatch performs particularly well with only a small number of labeled samples, we also include experiments for SVHN and CIFAR10 with all but 250 labels discarded. We also include experiments on STL10, a dataset designed for SSL, which has 5,000 labeled images and 100,000 unlabeled images drawn from a slightly different distribution than the labeled data. All but 1,000 labels are discarded for STL10.
Hyperparameters are tuned individually for each dataset, and the median of the last 20 checkpoints’ test error is reported, following berthelot2019mixmatch. We run 5 seeds.
Again, we reduce training to 300 epochs for both SVHN and CIFAR10, which is a typical training time for fully supervised models. We reduce the number of training epochs for STL10 significantly in the interest of training time. All other methodology follows the work of MixMatch. We note that berthelot2019mixmatch differs from oliver2018realistic in that it evaluates an exponential moving average of the model parameters, instead of using a learning rate decay schedule, and uses weight decay.
4.2.1 Baseline Methods
We run MixMatch with the official implementation, and use the parameters recommended in the original work for each dataset.
4.2.2 Implementation of Negative Sampling
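Recall that MixMatch outputs collections of samples with their generated labels. We label each sample with negative labels for the classes whose generated probability falls below a certain threshold. We then simply add our loss to the existing SSL loss function, computing the loss using the probability outputs of the network as usual. Namely,

$$\mathcal{L} = \mathcal{L}_{\mathcal{X}} + \lambda_{\mathcal{U}} \mathcal{L}_{\mathcal{U}} + \lambda_N \mathcal{L}_u,$$

where $\mathcal{L}_{\mathcal{X}}$ and $\mathcal{L}_{\mathcal{U}}$ are the MixMatch losses, $\mathcal{L}_u$ is our negative sampling loss with negative labels obtained by thresholding the labels generated by MixMatch, and $\lambda_N$ is its weighting.
We follow the practice of berthelot2019mixmatch and tune the loss weighting separately for each dataset. MixMatch + our loss takes only marginally longer to run than MixMatch on its own. The learning rate is fixed.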
CIFAR10. We evaluate the accuracy of each method with 4,000 labeled samples and 41,000 unlabeled samples, as is standard practice, and with 250 labeled samples and 44,750 unlabeled samples, where MixMatch performs much better than other SSL methods. The results are given in Table 2.
As in berthelot2019mixmatch, we use and . For , we use a threshold of and a coefficient of for 250 labeled samples and for 4,000 labeled samples.
We created 5 splits for each number of labeled samples, each with a different seed. Each model is trained on a different split, and test error is reported as mean and standard deviation.
Similar to the previous section, we find that adding our loss immediately improves the performance of MixMatch, by 2.01% with 250 labeled samples and by 0.13% with 4,000 labeled samples. The 250-label case may be the more interesting one, since it highlights the sample efficiency of the method.
|Method|250 labels|4,000 labels|
|---|---|---|
|MixMatch|14.49 ± 1.60|7.05 ± 0.10|
|MixMatch + Ours|12.48 ± 1.21|6.92 ± 0.12|
SVHN. We evaluate the accuracy of each method with 1,000 labeled samples and 64,932 unlabeled samples, as is standard practice, and 250 labeled samples and 65,682 unlabeled samples. The results are shown in Table 3.
Following the literature, we use and . For , we again use a threshold of and a coefficient of for both 250 labeled samples and 1,000 labeled samples.
We created 5 splits with 5 different seeds, where each model is trained on a different split and test error is reported with mean and standard deviation.
By adding our loss to MixMatch, the model achieves almost the same test error with 250 labeled samples (3.38%) as MixMatch alone achieves with 1,000 labeled samples (3.28%). In other words, in this case adding our loss is almost equivalent to having 4x as much labeled data. With 250 and 1,000 labeled samples, adding our loss to MixMatch improves performance by 0.37% and 0.14%, respectively, achieving state-of-the-art results.
|Method|250 labels|1,000 labels|
|---|---|---|
|MixMatch|3.75 ± 0.09|3.28 ± 0.11|
|MixMatch + Ours|3.38 ± 0.08|3.14 ± 0.11|
STL10. We evaluate the accuracy of each method with 1,000 labeled samples and 100,000 unlabeled samples. The results are given in Table 4.
Following the literature, we use and . For , we again use a threshold of and . We trained the model for significantly fewer epochs than in berthelot2019mixmatch; however, even in this case our method improves upon MixMatch, reducing test error slightly.
|Method|1,000 labels|
|---|---|
|MixMatch + Ours|21.74 ± 0.33|
5 Alternative methods
With computational efficiency in mind, we compare several ways of implementing our method in Table 5 on the F-MNIST dataset with a small Convolutional Neural Network. We split the F-MNIST dataset into 2,000 labeled and 58,000 unlabeled samples and report validation error at the end of training. In Table 5, the number after each method name is the number of negative labels assigned to each unlabeled sample. Specifically, we compare:
Supervised: trained only on the 2,000 labeled samples.
Uniform: negative labels are selected uniformly over all classes.
NN: we use the Nearest Neighbor (NN) method to exclude the class of the NN, to exclude the four classes of the nearest neighbors, or to label the sample with the class of the furthest NN.
Threshold: refers to the method of Section 4.1.2.
Oracle: negative labels are selected uniformly over all wrong classes.
|Method|Error|
|---|---|
|Uniform - 1|18.64 ± 0.38|
|Uniform - 3|19.35 ± 0.33|
|Exclude class of NN - 1|17.12 ± 0.15|
|Exclude 4 nearest classes with NN - 1|17.13 ± 0.21|
|Furthest class with NN - 1|16.76 ± 0.15|
|Oracle - 1|16.37 ± 0.12|
|Oracle - 3|15.20 ± 0.66|
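The negative-label selection strategies compared above can be sketched in plain Python. The NN variants follow our reading of the descriptions ("nearest classes" are the classes whose closest labeled example is nearest to the sample), distances are Euclidean, and all function names are hypothetical. The brute-force scan over the labeled set hints at why NN-based selection is slow in high dimensions.

```python
import math
import random


def uniform_negatives(num_classes, s, rng):
    """Uniform: s negative labels drawn from all classes (may include the true one)."""
    return set(rng.sample(range(num_classes), s))


def oracle_negatives(num_classes, s, true_label, rng):
    """Oracle: s negative labels drawn uniformly from the wrong classes only."""
    wrong = [k for k in range(num_classes) if k != true_label]
    return set(rng.sample(wrong, s))


def class_distances(x, labeled):
    """Distance from x to the nearest labeled example of each class."""
    best = {}
    for xi, yi in labeled:
        d = math.dist(x, xi)
        if yi not in best or d < best[yi]:
            best[yi] = d
    return best


def exclude_nearest_classes(x, labeled, num_classes, n_exclude, s, rng):
    """NN variant: draw s negatives uniformly after excluding the n_exclude
    classes closest to x (n_exclude=1 excludes the class of the NN)."""
    dists = class_distances(x, labeled)
    nearest = set(sorted(dists, key=dists.get)[:n_exclude])
    candidates = [k for k in range(num_classes) if k not in nearest]
    return set(rng.sample(candidates, s))


def furthest_class(x, labeled):
    """NN variant: the single negative label is the class whose nearest
    labeled example is furthest from x."""
    dists = class_distances(x, labeled)
    return {max(dists, key=dists.get)}


if __name__ == "__main__":
    labeled = [([0.0, 0.0], 0), ([1.0, 0.0], 1), ([5.0, 0.0], 2), ([9.0, 0.0], 3)]
    x = [0.2, 0.0]
    rng = random.Random(0)
    print(furthest_class(x, labeled), exclude_nearest_classes(x, labeled, 4, 1, 1, rng))
```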
Selecting negative labels uniformly over all classes appears to hurt performance, suggesting that negative labels must be selected more carefully in the classification setting. NN methods appear to improve over purely supervised training; however, their effectiveness is limited by long preprocessing times and the high dimensionality of the data.
The method described in Section 4.1.2, listed here as Threshold, achieves lower test error than the NN and Uniform methods. In particular, it is competitive with Oracle - 1, an oracle that labels each unlabeled sample with one negative label that is guaranteed not to be the sample's class.
It is no surprise that Oracle - 3 improves substantially over Oracle - 1. It is conceivable that methods can be developed to accurately select a small number of negative labels, and these may lead to even better results when combined with other SSL methods.
We stress that this is not a definitive list of ways to implement negative sampling in SSL, and that our fast proposed method, when combined with other SSL methods, already improves over the state of the art.
With simplicity, speed, and ease of tuning in mind, we proposed Negative Sampling in Semi-Supervised Learning, a semi-supervised learning method inspired by negative sampling that simply adds a loss term. We demonstrate its effectiveness when combined with existing SSL algorithms, producing the best overall results for non-Mixup-based SSL (combined with VAT) and for Mixup-based SSL (combined with MixMatch). We show improvements across a variety of tasks with only a minor increase in training time.
Appendix A Negative Sampling
We present the case of word2vec for negative sampling, where the numbers of words and contexts are such that a randomly picked (word, context) pair is, with high probability, unrelated. To draw the analogy, let us describe the intuition behind word2vec. Here, the task is to relate words, represented as $w$, with contexts, represented as $c$. We can conceptualize words as corresponding to the inputs $x$, and contexts as corresponding to the labels $y$. The negative sampling approach of Mikolov et al. (mikolov2013dist) considers the following objective: consider a pair $(w, c)$ of a word and a context. If this pair comes from valid data that correctly connects the two, we say that the pair came from the true data distribution; otherwise, we say that $(w, c)$ does not come from the true distribution.
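Formally, we denote by $P(D = 1 \mid w, c)$ the probability that $(w, c)$ falls in the first case, and by $P(D = 0 \mid w, c) = 1 - P(D = 1 \mid w, c)$ the probability that it does not. The paper models these probabilities as:

$$P(D = 1 \mid w, c) = \sigma(v_c^\top v_w) = \frac{1}{1 + e^{-v_c^\top v_w}}, \qquad P(D = 0 \mid w, c) = \sigma(-v_c^\top v_w),$$

where $v_c, v_w \in \mathbb{R}^d$ correspond to the vector representations of the context and the word, respectively.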
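Now, to find good vector representations given the data (we naively group all variables into $\theta$), we maximize the log-likelihood:

$$\arg\max_{\theta} \sum_{(w, c) \in \mathcal{D}} \log \sigma(v_c^\top v_w) + \sum_{(w, c) \notin \mathcal{D}} \log \sigma(-v_c^\top v_w).$$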
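Of course, we never take the whole dataset (the whole corpus $\mathcal{D}$) and perform gradient descent; rather, we perform SGD by considering only a mini-batch $\mathcal{B} \subset \mathcal{D}$ for the first term:

$$\sum_{(w, c) \in \mathcal{B}} \log \sigma(v_c^\top v_w) + \sum_{(w, c) \notin \mathcal{D}} \log \sigma(-v_c^\top v_w).$$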
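Also, we cannot consider *every* data point not in the dataset; rather, we perform negative sampling by selecting random pairs according to some probability distribution (this choice is important; word2vec draws negative contexts from the unigram distribution raised to the 3/4 power), say $k$ pairs per positive pair:

$$\sum_{(w, c) \in \mathcal{B}} \Big[ \log \sigma(v_c^\top v_w) + \sum_{j=1}^{k} \log \sigma\big(-v_{\tilde{c}_j}^\top v_{\tilde{w}_j}\big) \Big],$$

where the tildes represent the "non-valid" data.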
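The per-pair negative sampling objective can be sketched as follows, with embeddings as plain lists. The helper names are hypothetical, and `unigram_noise` illustrates the count^(3/4) noise distribution used by word2vec. In our SSL setting, the analogous "non-valid" pairs are (unlabeled sample, negative label).

```python
import math


def log_sigmoid(z):
    """Numerically stable log(sigma(z))."""
    if z >= 0:
        return -math.log1p(math.exp(-z))
    return z - math.log1p(math.exp(z))


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def sgns_objective(v_w, v_c, negative_pairs):
    """log sigma(v_c . v_w) + sum_j log sigma(-v_{c~_j} . v_{w~_j}) for one
    positive pair and its sampled non-valid pairs (to be maximized)."""
    obj = log_sigmoid(dot(v_c, v_w))
    for v_w_neg, v_c_neg in negative_pairs:
        obj += log_sigmoid(-dot(v_c_neg, v_w_neg))
    return obj


def unigram_noise(counts, power=0.75):
    """Noise distribution proportional to count^(3/4), as used in word2vec."""
    weights = {w: c ** power for w, c in counts.items()}
    z = sum(weights.values())
    return {w: wt / z for w, wt in weights.items()}


if __name__ == "__main__":
    print(sgns_objective([1.0, 0.0], [1.0, 0.0], [([0.0, 1.0], [0.0, 1.0])]))
    print(unigram_noise({"the": 16, "cat": 1}))  # about {'the': 0.889, 'cat': 0.111}
```

The 3/4 power flattens the unigram distribution, so rare words are sampled as negatives more often than their raw frequency would suggest.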