Semi-supervised learning has proven to be a powerful paradigm for leveraging unlabeled data to mitigate the reliance on large labeled datasets. In this work, we unify the current dominant approaches for semi-supervised learning to produce a new algorithm, MixMatch, that works by guessing low-entropy labels for data-augmented unlabeled examples and mixing labeled and unlabeled data using MixUp. We show that MixMatch obtains state-of-the-art results by a large margin across many datasets and labeled data amounts. For example, on CIFAR-10 with 250 labels, we reduce error rate by a factor of 4 (from 38% to 11%) and by a factor of 2 on STL-10. We also demonstrate how MixMatch can help achieve a dramatically better accuracy-privacy trade-off for differential privacy. Finally, we perform an ablation study to tease apart which components of MixMatch are most important for its success.
Much of the recent success in training large, deep neural networks is thanks in part to the existence of large labeled datasets. However, collecting labeled data is expensive for many learning tasks because it necessarily involves expert knowledge. This is perhaps best illustrated by medical tasks where measurements are made with expensive machinery and labels are the fruit of a time-consuming analysis, often drawing from the conclusions of multiple human experts. Furthermore, data labels may contain sensitive information that may be considered private. In comparison, in many tasks it is much easier or cheaper to obtain unlabeled data.
Semi-supervised learning chapelle2006semi (SSL) seeks to largely alleviate the need for labeled data by allowing a model to leverage unlabeled data. Many recent approaches for semi-supervised learning add a loss term which is computed on unlabeled data and encourages the model to generalize better to unseen data. In much recent work, this loss term falls into one of three classes (discussed further in Section 2): entropy minimization grandvalet2005semi; lee2013pseudo—which encourages the model to output confident predictions on unlabeled data; consistency regularization—which encourages the model to produce the same output distribution when its inputs are perturbed; and generic regularization—which encourages the model to generalize well and avoid overfitting the training data.
In this paper, we introduce MixMatch, an SSL algorithm with a single loss that gracefully unifies these dominant approaches to semi-supervised learning. Unlike previous methods, MixMatch targets all the properties at once, which we find leads to the following benefits:
Experimentally, we show that MixMatch obtains state-of-the-art results on all standard image benchmarks (section 4.2), for example obtaining an 11.08% error rate on CIFAR-10 with 250 labels (compared to the next-best method, which achieved 38%);
Furthermore, we show in an ablation study that MixMatch is greater than the sum of its parts;
We demonstrate in section 4.3 that MixMatch is useful for differentially private learning, enabling students in the PATE framework papernot2016semi to obtain new state-of-the-art results that simultaneously strengthen both the privacy guarantees provided and the accuracy achieved.
In short, MixMatch introduces a unified loss term for unlabeled data that seamlessly reduces entropy while maintaining consistency and remaining compatible with traditional regularization techniques.
To set the stage for MixMatch, we first introduce existing methods for SSL. We focus mainly on those which are currently state-of-the-art and that MixMatch builds on; there is a wide literature on SSL techniques that we do not discuss here (e.g., “transductive” models gammerman1998learning; joachims2003transductive; joachims1999transductive, graph-based methods zhu2003semi; bengio2006label, generative modeling Belkin+Niyogi-2002; LasserreJ2006; Russ+Geoff-nips-2007; Coates2011b; Goodfellow2011; kingma2014semi; pu2016variational; odena2016semi; salimans2016improved, etc.). More comprehensive overviews are provided in zhu2003semi; chapelle2006semi. In the following, we will refer to a generic model $p_{\text{model}}(y \mid x; \theta)$ which produces a distribution over class labels $y$ for an input $x$ with parameters $\theta$.
A common regularization technique in supervised learning is data augmentation, which applies input transformations assumed to leave class semantics unaffected. For example, in image classification, it is common to elastically deform or add noise to an input image, which can dramatically change the pixel content of an image without altering its label ciresan2010deep; simard2003best; cubuk2018autoaugment. Roughly speaking, this can artificially expand the size of a training set by generating a near-infinite stream of new, modified data. Consistency regularization applies data augmentation to semi-supervised learning by leveraging the idea that a classifier should output the same class distribution for an unlabeled example even after it has been augmented. More formally, consistency regularization enforces that an unlabeled example $x$ should be classified the same as $\text{Augment}(x)$, where $\text{Augment}(x)$ is a stochastic data augmentation function, such as a random spatial translation or additive noise.
In the simplest case, the “$\Pi$-Model” laine2016temporal (also called “Regularization with stochastic transformations and perturbations” sajjadi2016regularization) adds the loss term
$$\| p_{\text{model}}(y \mid \text{Augment}(x); \theta) - p_{\text{model}}(y \mid \text{Augment}(x); \theta) \|_2^2 \tag{1}$$
for unlabeled datapoints $x$. Note again that $\text{Augment}(x)$ is a stochastic transformation, so the two terms in eq. 1 are not identical. This approach has been applied to image classification benchmarks using a sophisticated augmentation process which includes rotation, shearing, additive Gaussian noise, etc. “Mean Teacher” tarvainen2017weight replaces one of the terms in eq. 1 with the output of the model using an exponential moving average of model parameter values. This provides a more stable target and was found empirically to significantly improve results. A drawback to these approaches is that they use domain-specific data augmentation strategies. “Virtual Adversarial Training” miyato2018virtual (VAT) addresses this by instead computing an additive perturbation to apply to the input which maximally changes the output class distribution. MixMatch utilizes a form of consistency regularization through the use of standard data augmentation for images (random horizontal flips and crops).
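To make the consistency term concrete, the following is a minimal NumPy sketch, not the implementation used in any of the cited works. The toy linear `predict` model, the Gaussian-noise `augment` function, and the `noise_std` parameter are assumptions introduced purely for illustration.

```python
import numpy as np

def softmax(z):
    # Numerically stable softmax over the last axis.
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def augment(x, rng, noise_std):
    # Hypothetical stochastic augmentation: additive Gaussian noise.
    return x + rng.normal(0.0, noise_std, size=x.shape)

def predict(x, w):
    # Toy linear classifier standing in for the model.
    return softmax(x @ w)

def consistency_loss(x, w, rng, noise_std=0.1):
    # Two independent augmentations of the same unlabeled inputs.
    p1 = predict(augment(x, rng, noise_std), w)
    p2 = predict(augment(x, rng, noise_std), w)
    # Squared L2 distance between the two predicted distributions,
    # averaged over the batch.
    return float(np.mean(np.sum((p1 - p2) ** 2, axis=-1)))

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 8))   # 4 unlabeled examples, 8 features
w = rng.normal(size=(8, 3))   # 3 classes
loss = consistency_loss(x, w, rng)
```

Because each call to `augment` draws fresh noise, the two predictions differ and the loss is generally non-zero; with `noise_std=0.0` the two terms coincide and the loss vanishes.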
A common underlying assumption in many semi-supervised learning methods is that the classifier’s decision boundary should not pass through high-density regions of the marginal data distribution. One way to enforce this is to require that the classifier output low-entropy predictions on unlabeled data. This is done explicitly in grandvalet2005semi by simply adding a loss term which minimizes the entropy of $p_{\text{model}}(y \mid x; \theta)$ for unlabeled data $x$. This form of entropy minimization was combined with VAT in miyato2018virtual to obtain stronger results. “Pseudo-Label” lee2013pseudo does entropy minimization implicitly by constructing hard labels from high-confidence predictions on unlabeled data and using these as training targets in a standard cross-entropy loss. MixMatch also implicitly achieves entropy minimization through the use of a “sharpening” function on the target distribution for unlabeled data, described in section 3.2.1.
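As a rough illustration of the two ideas in this paragraph, the sketch below computes the entropy of a predicted class distribution and builds a hard pseudo-label from a high-confidence prediction. The function names and the confidence `threshold` value are assumptions made for this example and are not taken from the cited methods.

```python
import numpy as np

def entropy(p, eps=1e-12):
    # Shannon entropy (in nats) of a categorical distribution.
    p = np.asarray(p, dtype=float)
    return float(-np.sum(p * np.log(p + eps)))

def pseudo_label(p, threshold=0.95):
    # Return a one-hot target if the prediction is confident enough,
    # otherwise None (the example would be skipped).
    # The threshold value is an illustrative assumption.
    p = np.asarray(p, dtype=float)
    if p.max() < threshold:
        return None
    one_hot = np.zeros_like(p)
    one_hot[p.argmax()] = 1.0
    return one_hot

uniform = [0.25, 0.25, 0.25, 0.25]    # maximally uncertain prediction
confident = [0.97, 0.01, 0.01, 0.01]  # low-entropy prediction
```

An entropy-minimization loss would push predictions such as `uniform` toward predictions such as `confident`, which has much lower entropy.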
Regularization refers to the general approach of imposing a constraint on a model to make it harder to memorize the training data and therefore hopefully make it generalize better to unseen data hinton1993keeping. A ubiquitous regularization technique is to add a loss term which penalizes the $L_2$ norm of the model parameters, which can be seen as enforcing a zero-mean identity-covariance Gaussian prior on the weight values hinton1993keeping. When using simple gradient descent, this loss term is equivalent to exponentially decaying the weight values towards zero. Since we are using Adam as our gradient optimizer, we use explicit “weight decay” rather than an $L_2$ loss term loshchilov2017fixing; zhang2018three.
More recently, the MixUp zhang2017mixup regularizer was proposed, which trains a model on convex combinations of both inputs and labels. MixUp can be seen as encouraging the model to have strictly linear behavior “between” examples, by requiring that the model’s output for a convex combination of two inputs is close to the convex combination of the output for each individual input verma2018manifold; verma2019interpolation; hataya2019unifying. We utilize MixUp in MixMatch both as a regularizer (applied to labeled datapoints) and a semi-supervised learning method (applied to unlabeled datapoints). MixUp has been previously applied to semi-supervised learning; in particular, the concurrent work of verma2019interpolation uses a subset of the methodology used in MixMatch. We clarify the differences in our ablation study (section 4.2.3).
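The convex-combination idea can be sketched in a few lines of NumPy. This is an illustrative sketch of standard MixUp rather than any author's code; the mixing coefficient is drawn from a symmetric Beta distribution as in the MixUp paper, while the function name, the `alpha` value, and the toy inputs are assumptions.

```python
import numpy as np

def mixup(x1, p1, x2, p2, alpha, rng):
    # Draw a mixing coefficient from a symmetric Beta distribution.
    lam = rng.beta(alpha, alpha)
    # Interpolate inputs and label distributions with the same coefficient.
    x = lam * x1 + (1.0 - lam) * x2
    p = lam * p1 + (1.0 - lam) * p2
    return x, p, lam

rng = np.random.default_rng(0)
x1, p1 = np.ones(4), np.array([1.0, 0.0])    # toy example of class 0
x2, p2 = np.zeros(4), np.array([0.0, 1.0])   # toy example of class 1
x_mix, p_mix, lam = mixup(x1, p1, x2, p2, alpha=0.5, rng=rng)
```

The mixed label `p_mix` remains a valid distribution, and a model trained on `(x_mix, p_mix)` is encouraged to behave linearly between the two original examples.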
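In this section, we introduce MixMatch, our proposed semi-supervised learning method. MixMatch is a “holistic” approach which incorporates ideas and components from the dominant paradigms for SSL discussed in section 2. Given a batch $\mathcal{X}$ of labeled examples with corresponding one-hot targets (representing one of $L$ possible labels) and an equally-sized batch $\mathcal{U}$ of unlabeled examples, MixMatch produces a processed batch of augmented labeled examples $\mathcal{X}'$ and a batch of augmented unlabeled examples with “guessed” labels $\mathcal{U}'$. $\mathcal{U}'$ and $\mathcal{X}'$ are then used in computing separate labeled and unlabeled loss terms. More formally, the combined loss $\mathcal{L}$ for semi-supervised learning is computed as
$$\mathcal{X}', \mathcal{U}' = \text{MixMatch}(\mathcal{X}, \mathcal{U}, T, K, \alpha) \tag{2}$$
$$\mathcal{L}_{\mathcal{X}} = \frac{1}{|\mathcal{X}'|} \sum_{x, p \in \mathcal{X}'} H(p, p_{\text{model}}(y \mid x; \theta)) \tag{3}$$
$$\mathcal{L}_{\mathcal{U}} = \frac{1}{L |\mathcal{U}'|} \sum_{u, q \in \mathcal{U}'} \| q - p_{\text{model}}(y \mid u; \theta) \|_2^2 \tag{4}$$
$$\mathcal{L} = \mathcal{L}_{\mathcal{X}} + \lambda_{\mathcal{U}} \mathcal{L}_{\mathcal{U}} \tag{5}$$
where $H(p, q)$ is the cross-entropy between distributions $p$ and $q$, and $T$, $K$, $\alpha$, and $\lambda_{\mathcal{U}}$ are hyperparameters described below. The full MixMatch algorithm is provided in algorithm 1, and a diagram of the label guessing process is shown in fig. 1. We describe each part of MixMatch in the following sections.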
As noted in section 2.1, a common approach for mitigating a lack of labeled data is to use data augmentation. Data augmentation introduces a function $\text{Augment}(x)$ which produces a stochastic transformation of the input datapoint $x$ in such a way that its label remains unchanged. To reiterate, different applications of $\text{Augment}$ will produce different (stochastic) outputs. As is typical in many SSL methods, we use data augmentation both on labeled and unlabeled data. For each $x_b$ in the batch of labeled data $\mathcal{X}$, we generate a transformed version $\hat{x}_b = \text{Augment}(x_b)$ (algorithm 1, line 3). For each $u_b$ in the batch of unlabeled data $\mathcal{U}$, we generate $K$ augmentations $\hat{u}_{b,k} = \text{Augment}(u_b)$, $k \in (1, \ldots, K)$ (algorithm 1, line 5). These individual augmentations are used for generating a “guessed label” $q_b$ for each $u_b$, through a process we describe in the following section.
For each unlabeled example in $\mathcal{U}$, MixMatch produces a “guess” for the example’s label using the model’s predictions. This guess is later used in the unsupervised loss term. To do so, we compute the average of the model’s predicted class distributions across all the $K$ augmentations of $u_b$ by
$$\bar{q}_b = \frac{1}{K} \sum_{k=1}^{K} p_{\text{model}}(y \mid \hat{u}_{b,k}; \theta) \tag{6}$$
in algorithm 1, line 7. Using data augmentation to obtain an artificial target for an unlabeled example is common in consistency regularization methods laine2016temporal; sajjadi2016regularization; tarvainen2017weight.
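The averaging step is simple enough to show directly. The sketch below assumes the model's predictions for the augmentations of one unlabeled example have already been collected into an array with one row per augmentation; the function name and the toy numbers are illustrative only.

```python
import numpy as np

def guess_label(aug_predictions):
    # aug_predictions has shape (num_augmentations, num_classes):
    # one predicted class distribution per augmentation of the same example.
    return np.asarray(aug_predictions, dtype=float).mean(axis=0)

# Toy predictions for two augmentations of a single unlabeled example.
preds = np.array([[0.7, 0.2, 0.1],
                  [0.5, 0.3, 0.2]])
q_bar = guess_label(preds)
```

Since every row is a probability distribution, their mean `q_bar` is one as well.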
In generating a label guess, we perform one additional step inspired by the success of entropy minimization in semi-supervised learning (discussed in section 2.2). Given the average prediction over augmentations $\bar{q}_b$, we apply a sharpening function to reduce the entropy of the label distribution. In practice, for the sharpening function, we use the common approach of adjusting the “temperature” of this categorical distribution goodfellow2016deep, which is defined as the operation
$$\text{Sharpen}(p, T)_i := p_i^{1/T} \Big/ \sum_{j=1}^{L} p_j^{1/T} \tag{7}$$
where $p$ is some input categorical distribution (specifically in MixMatch, $p$ is the average class prediction over augmentations $\bar{q}_b$, as shown in algorithm 1, line 8) and $T$ is a hyperparameter. As $T \to 0$, the output of $\text{Sharpen}(p, T)$ will approach a Dirac (“one-hot”) distribution. Since we will later use $q_b = \text{Sharpen}(\bar{q}_b, T)$ as a target for the model’s prediction for an augmentation of $u_b$, lowering the temperature encourages the model to produce lower-entropy predictions.
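A minimal NumPy sketch of temperature sharpening follows. It raises each probability to the power of the inverse temperature and renormalizes; the function name and the example temperatures are assumptions chosen for illustration, not values prescribed by the text.

```python
import numpy as np

def sharpen(p, temperature):
    # Raise each probability to 1/temperature, then renormalize
    # so the result is again a distribution over classes.
    p = np.asarray(p, dtype=float)
    powered = p ** (1.0 / temperature)
    return powered / powered.sum(axis=-1, keepdims=True)

q_bar = np.array([0.6, 0.25, 0.15])
unchanged = sharpen(q_bar, 1.0)     # temperature 1 leaves the input as-is
sharper = sharpen(q_bar, 0.5)       # lower temperature, lower entropy
near_one_hot = sharpen(q_bar, 0.01) # approaches a one-hot distribution
```

With a temperature of one the distribution is unchanged, and as the temperature shrinks the mass concentrates on the most likely class.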
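As the final step of MixMatch, we utilize MixUp zhang2017mixup. To use MixUp for semi-supervised learning, we apply it both to labeled examples and unlabeled examples with label guesses (generated as described in section 3.2). Unlike past work using MixUp for SSL verma2018manifold; verma2019interpolation; hataya2019unifying, we “mix” labeled examples with unlabeled examples and vice versa, which we find results in improved performance (section 4.2.3). In our combined loss function (described in section 3.4), we use separate loss terms for labeled and unlabeled data. This causes an issue when using MixUp in the originally proposed form; instead, for a pair of two examples with their corresponding (one-hot) labels $(x_1, p_1), (x_2, p_2)$ we define a slightly modified MixUp as computing $(x', p')$ by
$$\lambda \sim \text{Beta}(\alpha, \alpha) \tag{8}$$
$$\lambda' = \max(\lambda, 1 - \lambda) \tag{9}$$
$$x' = \lambda' x_1 + (1 - \lambda') x_2 \tag{10}$$
$$p' = \lambda' p_1 + (1 - \lambda') p_2 \tag{11}$$
where $\alpha$ is a hyperparameter. Because $\lambda' \geq 1/2$, the result $x'$ is always at least as close to $x_1$ as to $x_2$. The originally-proposed MixUp can be seen as omitting eq. 9 (i.e. setting $\lambda' = \lambda$). To apply MixUp, we first collect all augmented labeled examples and their labels into
$$\hat{\mathcal{X}} = ((\hat{x}_b, p_b); b \in (1, \ldots, B)) \tag{12}$$
and all augmentations of all unlabeled examples with their guessed labels into
$$\hat{\mathcal{U}} = ((\hat{u}_{b,k}, q_b); b \in (1, \ldots, B), k \in (1, \ldots, K)) \tag{13}$$
where $B$ denotes the batch size (algorithm 1, lines 10–11). Then, we combine these collections and shuffle the result to form $\mathcal{W}$, which will serve as a data source for MixUp (algorithm 1, line 12). For each $i$-th example-label pair in $\hat{\mathcal{X}}$, we compute $\text{MixUp}(\hat{\mathcal{X}}_i, \mathcal{W}_i)$ and add the result to the collection $\mathcal{X}'$ (algorithm 1, line 13). Note that because of our slight modification to MixUp, the entries in $\mathcal{X}'$ are guaranteed to be “closer” (in terms of interpolation) to an original labeled datapoint than the corresponding interpolant from $\mathcal{W}$. We similarly compute $\mathcal{U}'_i = \text{MixUp}(\hat{\mathcal{U}}_i, \mathcal{W}_{i + |\hat{\mathcal{X}}|})$ for $i \in (1, \ldots, |\hat{\mathcal{U}}|)$, intentionally using the remainder of $\mathcal{W}$ that was not used in the construction of $\mathcal{X}'$ (algorithm 1, line 14). To summarize, MixMatch transforms $\mathcal{X}$ into $\mathcal{X}'$, a collection of labeled examples which have had data augmentation and MixUp (potentially mixed with an unlabeled example) applied. Similarly, $\mathcal{U}$ is transformed into $\mathcal{U}'$, a collection of multiple augmentations of each unlabeled example with corresponding label guesses.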
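The mixing procedure described above can be sketched as follows. This is an illustrative NumPy sketch under stated assumptions, not the authors' code: inputs are flat feature vectors, the function names are hypothetical, and the mixing coefficient is drawn from a symmetric Beta distribution and then replaced by the larger of itself and its complement so that each output stays closer to its first argument.

```python
import numpy as np

def mixup_biased(x1, p1, x2, p2, alpha, rng):
    # Modified MixUp: the coefficient is at least 0.5, so the result
    # is closer to (x1, p1) than to (x2, p2).
    lam = rng.beta(alpha, alpha)
    lam = max(lam, 1.0 - lam)
    return lam * x1 + (1.0 - lam) * x2, lam * p1 + (1.0 - lam) * p2

def mix_batches(x_hat, p, u_hat, q, alpha, rng):
    # x_hat: augmented labeled inputs, p: their one-hot labels.
    # u_hat: augmented unlabeled inputs, q: their guessed labels.
    all_x = np.concatenate([x_hat, u_hat])
    all_p = np.concatenate([p, q])
    # Shuffled copy of the combined collection, used as the mixing partner.
    perm = rng.permutation(len(all_x))
    w_x, w_p = all_x[perm], all_p[perm]
    mixed = [mixup_biased(all_x[i], all_p[i], w_x[i], w_p[i], alpha, rng)
             for i in range(len(all_x))]
    xs = np.stack([m[0] for m in mixed])
    ps = np.stack([m[1] for m in mixed])
    n = len(x_hat)
    # The first n rows used the first n shuffled partners (labeled output);
    # the remaining rows used the rest of the shuffled collection.
    return (xs[:n], ps[:n]), (xs[n:], ps[n:])

rng = np.random.default_rng(0)
x_hat = rng.normal(size=(4, 5))               # 4 labeled examples
p = np.eye(3)[rng.integers(0, 3, size=4)]     # one-hot labels, 3 classes
u_hat = rng.normal(size=(8, 5))               # 4 unlabeled examples x 2 augmentations
q = np.full((8, 3), 1.0 / 3.0)                # placeholder guessed labels
(x_mix, p_mix), (u_mix, q_mix) = mix_batches(x_hat, p, u_hat, q, alpha=0.75, rng=rng)
```

In this sketch every mixed labeled example keeps at least half of its weight on its original label, which mirrors the "closer" guarantee discussed above.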
Given our processed batches $\mathcal{X}'$ and $\mathcal{U}'$ produced by MixMatch, we use the standard semi-supervised loss shown in eqs. 3 to 5. Equation 5 combines the typical cross-entropy loss between labels and model predictions from $\mathcal{X}'$ with a squared $L_2$ loss on predictions and guessed labels from $\mathcal{U}'$. The squared $L_2$ loss in eq. 4 corresponds to the multiclass Brier score brier1950verification which, unlike the cross-entropy, is bounded and less sensitive to completely incorrect predictions. As a result, it has frequently been used as a loss for predictions on unlabeled data in semi-supervised learning laine2016temporal; tarvainen2017weight as well as a measure of predictive uncertainty lakshminarayanan2017simple. Note that the guessed labels in eq. 4 are a function of the model parameters; however, as is standard when using this form of loss function laine2016temporal; tarvainen2017weight; miyato2018virtual; oliver2018realistic, we do not propagate gradients through the guessed labels.
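The combined loss can be illustrated with a short NumPy sketch. The function name, argument names, and the choice to average the squared error over both examples and classes are assumptions of this example; in a real training loop the guessed labels would additionally be treated as constants so that no gradient flows through them.

```python
import numpy as np

def combined_loss(p_x, pred_x, q_u, pred_u, lambda_u, eps=1e-12):
    # Labeled term: cross-entropy between targets and predictions,
    # averaged over the labeled batch.
    loss_x = float(-np.mean(np.sum(p_x * np.log(pred_x + eps), axis=1)))
    # Unlabeled term: squared error between guessed labels and predictions,
    # averaged over examples and classes (assumed normalization).
    loss_u = float(np.mean((q_u - pred_u) ** 2))
    return loss_x + lambda_u * loss_u, loss_x, loss_u

p_x = np.array([[1.0, 0.0]])      # one labeled target
pred_x = np.array([[0.5, 0.5]])   # model prediction on the labeled input
q_u = np.array([[1.0, 0.0]])      # one guessed label
pred_u = np.array([[0.0, 1.0]])   # completely wrong prediction
total, loss_x, loss_u = combined_loss(p_x, pred_x, q_u, pred_u, lambda_u=2.0)
```

Even for the completely wrong unlabeled prediction the squared term stays bounded, whereas a cross-entropy term would grow without limit as the predicted probability of the target class approaches zero.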
Since MixMatch combines multiple mechanisms for leveraging unlabeled data, it introduces various hyperparameters: specifically, the sharpening temperature $T$, number of unlabeled augmentations $K$, $\alpha$ parameter for $\text{Beta}$ in MixUp, and the unsupervised loss weight $\lambda_{\mathcal{U}}$. In general, semi-supervised learning methods with many hyperparameters can be problematic to apply in practice due to the difficulty in using cross-validation with small validation sets oliver2018realistic; rasmus2015semi. However, we find in practice that most of MixMatch’s hyperparameters can be fixed and do not need to be tuned on a per-experiment or per-dataset basis. Specifically, for all experiments we set and . Further, we only change and on a per-dataset basis; we found that and are good starting points for tuning.
To test the effectiveness of MixMatch, we apply it to standard semi-supervised learning benchmarks (section 4.2) and provide an extensive ablation study to tease apart the contribution of each of MixMatch’s components (section 4.2.3). As an additional application, we consider privacy-preserving learning in section 4.3.
Unless otherwise noted, in all experiments we use the “Wide ResNet-28” model from oliver2018realistic; further details of that model are available in the appendix of oliver2018realistic. Overall, our implementation of the model and training procedure closely matches that of oliver2018realistic, except for the following differences: First, instead of employing a learning rate schedule, we simply evaluate models using an exponential moving average of their parameters with a decay rate of . Second, we utilize weight decay as regularization in all models, decaying weights by at each update for the Wide ResNet-28 model. Finally, we save checkpoint every training samples and simply report the median of the last 20 checkpoints’ error rate. This simplifies the implementation, at a potential cost of an increase in error rate which could be obtained by, for example, averaging checkpoints athiwaratkun2018improving or choosing the checkpoint with the lowest validation error.
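The evaluation protocol above has two simple ingredients that can be sketched in NumPy: an exponential moving average (EMA) of the parameters, and the median error over the last 20 checkpoints. The function names, the dictionary-of-arrays parameter format, and the decay value used in the demonstration are assumptions of this sketch.

```python
import numpy as np

def ema_update(ema_params, params, decay):
    # Move each averaged parameter a small step toward its current value.
    return {name: decay * ema_params[name] + (1.0 - decay) * params[name]
            for name in params}

def reported_error(checkpoint_errors, last=20):
    # Median error rate over the most recent checkpoints.
    return float(np.median(checkpoint_errors[-last:]))

# Toy demonstration: the averaged weights approach the (fixed) raw weights.
params = {"w": np.ones(3)}
ema = {"w": np.zeros(3)}
for _ in range(100):
    ema = ema_update(ema, params, decay=0.9)

errors = list(range(30))          # placeholder per-checkpoint error values
summary = reported_error(errors)  # median of the last 20 entries
```

Evaluating with the averaged parameters smooths out step-to-step noise in the raw weights, which is the role a learning rate schedule would otherwise play here.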
First, we evaluate the effectiveness of on four standard benchmark datasets: CIFAR-10 and CIFAR-100 krizhevsky2009learning, SVHN netzer2011reading, and STL-10 coates2011analysis. The first three datasets are common image classification benchmarks for supervised learning; standard practice for evaluating semi-supervised learning on these datasets is to treat most of the dataset as unlabeled and use a small portion (e.g. a few hundred or thousand labels) as labeled data. STL-10 is a dataset specifically designed for SSL, with 5,000 labeled images and 100,000 unlabeled images which are drawn from a slightly different distribution than the labeled data.
As baselines for comparison, we consider the four methods considered in oliver2018realistic ($\Pi$-Model laine2016temporal; sajjadi2016regularization, Mean Teacher tarvainen2017weight, Virtual Adversarial Training miyato2018virtual, and Pseudo-Label lee2013pseudo) which are described in section 2. We also use MixUp zhang2017mixup on its own as a baseline. MixUp is designed as a regularizer for supervised learning, so we modify it for SSL by applying it both to labeled examples (mixing pairs of labeled examples) and unlabeled examples (mixing pairs of unlabeled examples and using the result as a guessed label). In accordance with standard usage of MixUp, we use a cross-entropy loss between the MixUp-generated guess label and the model’s prediction. As advocated by oliver2018realistic, we reimplemented each of these methods in the same codebase and applied them to the same model (described in section 4.1) to ensure a fair comparison. We re-tuned the hyperparameters for each baseline method, which generally resulted in a marginal accuracy improvement compared to those in oliver2018realistic, thereby providing a more competitive experimental setting for testing MixMatch.
For CIFAR-10, we evaluate the accuracy of each method with a varying number of labeled examples from to (as is standard practice). The results can be seen in fig. 2. We used and for CIFAR-10. We created 5 splits for each number of labeled points, each with a different random seed. Each model was trained on each split, and we report the mean and variance of the error rate across splits. We find that MixMatch outperforms all other methods by a significant margin, for example reaching an error rate of with labels. For reference, on the same model, fully supervised training on all samples achieves an error rate of . Furthermore, MixMatch obtains an error rate of with only labels. For comparison, at labels the next-best-performing method (VAT miyato2018virtual) achieves an error rate of , over higher than MixMatch when the fully supervised error rate is taken as the limit under our model settings. In addition, at labels the next-best-performing method (Mean Teacher tarvainen2017weight) obtains an error rate of , which suggests that MixMatch can achieve similar performance with only as many labels. We believe that the most interesting comparisons are those with very few labeled data points, since they reveal a method’s sample efficiency, which is central to semi-supervised learning.
Some prior work tarvainen2017weight; athiwaratkun2018improving has also considered the use of a larger, million-parameter model. Our base model, as used in oliver2018realistic, has only million parameters, which confounds comparison with these results. For a more reasonable comparison to these results, we measure the effect of increasing the width of our base ResNet model and evaluate MixMatch’s performance on a 28-layer Wide ResNet model which has filters per layer, resulting in million parameters. We also evaluate MixMatch on this larger model on CIFAR-100 with labels, to compare to the corresponding result from athiwaratkun2018improving. The results are shown in table 2. In general, MixMatch matches or outperforms the best results from athiwaratkun2018improving, though we note that the comparison still remains problematic because the model from tarvainen2017weight; athiwaratkun2018improving also uses more sophisticated “shake-shake” regularization gastaldi2017shake. For this model, we used a weight decay of . We used for CIFAR-10 and for CIFAR-100.
|Mean Teacher tarvainen2017weight||-|
As with CIFAR-10, we evaluate the performance of each SSL method on SVHN with a varying number of labels from to . As is standard practice, we first consider the setting where the -example training set is split into labeled and unlabeled data. The results are shown in fig. 3. We used and for SVHN. Here again the models were evaluated on 5 splits for each number of labeled points, each with a different random seed. We found MixMatch’s performance to be relatively constant (and better than all other methods) across all amounts of labeled data. Surprisingly, after additional tuning we were able to obtain extremely good performance from Mean Teacher tarvainen2017weight, though its error rate was consistently slightly higher than MixMatch’s.
Note that SVHN has two training sets: train and extra. In fully supervised learning, both sets are concatenated to form the full training set ( samples). In semi-supervised learning, for historical reasons unknown to us, the extra set was left aside and only train was used ( samples). We argue that leveraging both train and extra for the unlabeled data is more interesting since it exhibits a higher ratio of unlabeled samples over labeled ones. We report error rates for both SVHN and SVHN+Extra in table 3. For SVHN+Extra we used and a weight decay of ; as expected, with more samples the training required less regularization. We found that on both training sets, MixMatch nearly matches the fully-supervised performance on the same training set almost immediately; for example, MixMatch achieves an error rate of with only 250 labels on SVHN+Extra compared to the fully-supervised performance of . Interestingly, on SVHN+Extra MixMatch outperformed fully supervised training on SVHN without extra ( error) for every labeled data amount considered. To emphasize the importance of this, consider the following scenario: You have examples from SVHN with examples labeled and are given a choice: You can either obtain more unlabeled data and use MixMatch or obtain more labeled data and use fully-supervised learning. Our results suggest that obtaining additional unlabeled data and using MixMatch is more effective, which conveniently is likely much cheaper than obtaining more labels.
STL-10 is designed to be used with predefined training set folds with examples each. However, some prior work trains on all examples. We therefore compare in both experimental settings. With examples MixMatch surpasses both the state-of-the-art for examples as well as the state-of-the-art using all labeled examples. Note that none of the baseline methods in table 2 use the same experimental setup (model architecture, training procedure, etc.) so it is difficult to directly compare the results; however, because MixMatch obtains the lowest error by a factor of two, we take this to be a vote of confidence in our method. We used and for STL-10.
Since MixMatch combines various semi-supervised learning mechanisms, it has a good deal in common with existing methods in the literature. As a result, we study the effect of removing or adding components in order to provide additional insight into what makes MixMatch performant. Specifically, we measure the effect of
using the mean class distribution over $K$ augmentations or using the class distribution for a single augmentation (i.e. setting $K = 1$)
removing temperature sharpening (i.e. setting $T = 1$)
using an exponential moving average (EMA) of model parameters when producing guessed labels, as is done by Mean Teacher tarvainen2017weight
performing MixUp between labeled examples only, unlabeled examples only, and without mixing across labeled and unlabeled examples
using Interpolation Consistency Training verma2019interpolation, which can be seen as a special case of this ablation study where only unlabeled MixUp is used, no sharpening is applied, and EMA parameters are used for label guessing.
We carried out the ablation on CIFAR-10 with and labels; the results are shown in table 4. We find that each component contributes to MixMatch’s performance, with the most dramatic differences in the -label setting. Despite Mean Teacher’s effectiveness on SVHN (fig. 3), we found that using a similar EMA of parameter values hurt MixMatch’s performance slightly.
|without distribution averaging ($K = 1$)|
|without temperature sharpening ($T = 1$)|
|with parameter EMA|
|with MixUp on labeled only|
|with MixUp on unlabeled only|
|with MixUp on separate labeled and unlabeled|
|Interpolation Consistency Training verma2019interpolation|
Learning with privacy is an excellent way to measure our approach’s ability to generalize. Indeed, protecting the privacy of training data amounts to proving that the model does not overfit: a learning algorithm is said to be differentially private (differential privacy being the most widely accepted technical definition of privacy) if adding, modifying, or removing any of its training samples would not result in a statistically significant difference in the model parameters learned. For this reason, learning with differential privacy is, in practice, a form of regularization.
Each access to the training data constitutes a potential leakage of private information. This sensitive information is often encoded in the pairing between an input and its label. Hence, approaches for deep learning from private training data, such as differentially private SGD abadi2016deep but even more so PATE papernot2016semi, benefit from accessing as few labeled private training points as possible when computing updates to the model parameters. Semi-supervised learning is a natural fit for this setting. We show that MixMatch significantly improves upon the state-of-the-art for learning with differential privacy.
We use the PATE framework for learning with privacy. A student is trained in a semi-supervised way from public unlabeled data, part of which is labeled by an ensemble of teachers with access to private labeled training data. The fewer labels a student requires to reach a fixed accuracy, the stronger is the privacy guarantee it provides. Teachers use a noisy voting mechanism to respond to label queries from the student, and they may choose not to provide a label when they cannot reach a sufficiently strong consensus. For this reason, the fact that MixMatch improves the performance of PATE also illustrates MixMatch’s improved generalization from few canonical exemplars of each class.
We compare the accuracy-privacy trade-off achieved by MixMatch to a VAT miyato2018virtual baseline on SVHN. VAT achieved the previous state-of-the-art of test accuracy for a privacy loss of papernot2018scalable. Because MixMatch performs well with few labeled points, it is able to achieve test accuracy for a much smaller privacy loss of (). Because is used to measure the degree of privacy, the improvement is approximately . A privacy loss below 1 corresponds to a much stronger privacy guarantee. When interpreting the test accuracy, note that the experimental setup used to evaluate PATE as in (papernot2016semi), is different from the rest of this paper because the student trained with MixMatch has access to less training data (no more than 10K points here) than the teachers.
We introduced MixMatch, a semi-supervised learning method which combines ideas and components from the current dominant paradigms for semi-supervised learning. Through extensive experiments on semi-supervised and privacy-preserving learning, we found that MixMatch exhibited significantly improved performance compared to other methods in all settings we studied, often by a factor of two or more reduction in error rate. In future work, we are interested in incorporating additional ideas from the semi-supervised learning literature into hybrid methods and continuing to explore which components result in effective algorithms. Separately, most modern work on semi-supervised learning algorithms is evaluated on image benchmarks; we are interested in exploring the effectiveness of MixMatch in other domains.
We would like to thank Balaji Lakshminarayanan for his helpful theoretical insights.
|$H(p, q)$|Cross-entropy between “target” distribution $p$ and “predicted” distribution $q$|
|$x$|A labeled example, used as input to a model|
|$p$|A (one-hot) label|
|$L$|The number of possible label classes (the dimensionality of $p$)|
|$\mathcal{X}$|A batch of labeled examples and their labels|
|$\mathcal{X}'$|A batch of processed labeled examples produced by MixMatch|
|$u$|An unlabeled example, used as input to a model|
|$q$|A guessed label distribution for an unlabeled example|
|$\mathcal{U}$|A batch of unlabeled examples|
|$\mathcal{U}'$|A batch of processed unlabeled examples with their label guesses produced by MixMatch|
|$\theta$|The model’s parameters|
|$p_{\text{model}}(y \mid x; \theta)$|The model’s predicted distribution over classes|
|$\text{Augment}(x)$|A stochastic data augmentation function that returns a modified version of $x$. For example, $\text{Augment}(\cdot)$ could implement randomly shifting an input image, or implement adding a perturbation sampled from a Gaussian distribution to $x$.|
|$\lambda_{\mathcal{U}}$|A hyper-parameter weighting the contribution of the unlabeled examples to the training loss|
|$\alpha$|Hyperparameter for the $\text{Beta}$ distribution used in MixUp|
|$T$|Temperature parameter for sharpening used in MixMatch|
|$K$|Number of augmentations used when guessing labels in MixMatch|
Training the same model with supervised learning on the entire -example training set achieved an error rate of .
Training the same model with supervised learning on the entire -example training set achieved an error rate of .
Training the same model with supervised learning on the entire -example training set achieved an error rate of .