Deep Neural Networks (DNNs) have been successfully used in many real-world applications (Amodei et al, 2016; He et al, 2016). However, training DNNs relies on large-scale, high-quality datasets, which has become a core problem in practice. First, large-scale data annotation is highly expensive, so only a small number of labels can be accessed (Li et al, 2019; van Engelen and Hoos, 2020). Second, collected data usually follows a long-tailed distribution (He and Garcia, 2009; Liu et al, 2019; Wei and Li, 2019; Wei et al, 2021b, 2022), where only some classes (the majority class) have sufficient training samples while the other classes (the minority class) have only a few samples, as shown in Figure 0(a).
To utilize unlabeled data, Semi-Supervised Learning (SSL) has emerged as an appealing solution (Miyato et al, 2018; Tarvainen and Valpola, 2017; Berthelot et al, 2019b; Sohn et al, 2020; Berthelot et al, 2019a; Zhou et al, 2021; Guo et al, 2020). It makes assumptions about the data distribution and builds a learner that exploits unlabeled samples by selecting confident pseudo-labels. However, existing SSL methods have been shown to produce pseudo-labels biased towards the majority class (Kim et al, 2020), leading to undesirable performance.
Recently, Long-Tailed Semi-Supervised Learning (LTSSL) has been proposed to improve the performance of SSL models on long-tailed data. The main ideas of existing LTSSL methods (Kim et al, 2020; Wei et al, 2021a; Lee et al, 2021) are two-fold. One is to improve the quality of pseudo-labels from the perspective of SSL. The other is to employ class-balanced sampling or post-hoc classifier adjustment to alleviate class imbalance from the long-tailed perspective. These methods can improve the performance of conventional SSL models. However, the improvements come at the cost of high computational overhead or of information loss due to undersampling of the data.
How all data can be efficiently and effectively utilized is the core challenge of LTSSL and the focus of this paper. To this end, we propose a new method called TRAS (TRAnsfer and Share) which has two key ingredients. Figure 0(b) showcases the effectiveness of TRAS in the minority class.
(a) Long-tailed distribution of the training set under the main setting of CIFAR-10-LT. (b) Minority-class accuracy (%) on the CIFAR-10-LT dataset under class imbalance ratios of 50, 100, and 150 with 20% of labels available. The proposed TRAS substantially improves the minority-class accuracy.
First, we compensate for the minority-class training by generating a more balanced pseudo-label distribution. Under the guidance of pseudo-label distribution, DNNs can mine the interaction information between classes to obtain richer information for minority classes. The idea of learning from label distribution has been explored in previous literature, such as label distribution learning (Geng, 2016; Gao et al, 2017; Wang and Geng, 2019) and knowledge distillation (Xiang et al, 2020; He et al, 2021; Iscen et al, 2021), which however is still underexplored in LTSSL. To generate label distribution, knowledge distillation is a common approach via well-trained teacher models. Such a teacher model is not always available for SSL models because of limited long-tailed labeled data and high computation overhead. Alternatively, we employ a conventional SSL model with normally low accuracy on the minority class. This conventional SSL model is able to teach the learning of the student model after applying our proposed logit transformation. This transformation is particularly designed to enhance the minority-class supervisory signals without introducing extra computational cost. Subsequently, through training the student model by imitating the enhanced supervisory signals, the minority class will receive significant attention.
Second, we propose to merge the training of teacher and student models into a single procedure to reduce the computational cost. To this end, we use a double-branch neural network with a shared feature extractor and two classifiers that produce the predictions of the teacher and the student. The neural network is then trained end-to-end with a joint objective over these two classifiers. In addition to reducing training cost and simplifying the approach, we empirically find that both classifiers help improve representation learning and learn clear classification boundaries between classes.
Our main contributions are summarized as follows:
A new LTSSL method TRAS is proposed, which significantly improves the minority-class training without introducing extra training cost.
TRAS transfers pseudo-label distribution from a vanilla SSL network (teacher) to another network (student) via a new logit transformation, instead of trying hard to construct a sophisticated LTSSL teacher model.
TRAS reveals the importance of the balancedness of pseudo-label distribution in transfer for LTSSL.
TRAS merges the training of teacher and student models by sharing the feature extractor, which simplifies the training procedure and benefits the representation learning.
TRAS achieves state-of-the-art performance in various experiments. Particularly, it improves minority-class performance by about 7% in accuracy.
2 Related Work
Semi-supervised learning. Existing SSL methods aim to use unlabeled data to improve generalization. For this purpose, consistency regularization and entropy minimization have become the most frequently used techniques and demonstrate considerable performance improvements. Specifically, Mean-Teacher (Tarvainen and Valpola, 2017) imposes consistency regularization between the prediction of the current model and that of a self-ensembled model obtained using an exponential moving average. Virtual Adversarial Training (VAT) (Miyato et al, 2018) encourages the model to minimize the discrepancy between its predictions for unlabeled data before and after applying adversarial perturbation. MixMatch (Berthelot et al, 2019b) minimizes the entropy of model predictions by sharpening the pseudo-label distribution. ReMixMatch (Berthelot et al, 2019a) improves MixMatch by adding a distribution alignment regularizer and augmentation anchoring. FixMatch (Sohn et al, 2020) merges consistency regularization and entropy minimization by regularizing the predictions for weakly augmented and strongly augmented unlabeled data. However, the above-mentioned methods assume that both labeled and unlabeled data are class-balanced, leading to poor performance on the minority class when working on long-tailed datasets.
Long-tailed semi-supervised learning. To deal with long-tailed datasets, several LTSSL methods have been proposed. In a nutshell, existing methods aim to select pseudo-labels that are not only confident but also more class-balanced, in order to improve generalization for minority classes. For instance, DARP (Kim et al, 2020) proposes to estimate the underlying class distribution of unlabeled data, which is used to regularize the distribution of pseudo-labels. To this end, a convex optimization problem is solved. Additionally, CReST (Wei et al, 2021a) proposes to use class-aware confidence thresholds to select more pseudo-labels for the minority class. Recently, ABC (Lee et al, 2021) proposes to use an auxiliary balanced classifier built upon a conventional SSL model by class-balanced undersampling. However, these approaches suffer from either high computational cost or loss of supervisory information. In this work, we propose a new algorithm, TRAS, which fully utilizes both labeled and unlabeled data through efficient pseudo-label distribution transfer and greatly improves the performance on the minority class.
3 Method: TRAS
We now introduce the problem setting in Section 3.1 and develop our proposed method TRAS, which consists of two key ingredients described in Section 3.2 and Section 3.3. Figure 2 shows the framework of the proposed TRAS.
3.1 Problem Setting
Let be a labeled dataset, where denotes the training example and is the corresponding label. We introduce an unlabeled dataset where is the unlabeled data point. Following ABC (Lee et al, 2021), we assume that the class distributions of and are identical. We denote the number of labeled data points of class as (notice that ), assuming that all classes are sorted by cardinality in descending order . Corresponding to LTSSL, we set the ratio of labeled data as and the ratio of the class imbalance as . Following previous LTSSL works, we divide the class space into the majority class and the minority class according to their frequencies in the training data. Finally, our goal is to learn a model which generalizes well on both the majority class and the minority class.
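For reference, the setting above can be written in one conventional notation. Every symbol below is an assumption chosen for illustration, not notation recovered from the text, and defining the label ratio as the labeled fraction of all training points is likewise an assumed convention.

```latex
\begin{align*}
\mathcal{D}_L &= \{(x_n, y_n)\}_{n=1}^{N}, \qquad y_n \in \{1, \dots, C\}, \\
\mathcal{D}_U &= \{u_m\}_{m=1}^{M}, \\
N &= \sum_{c=1}^{C} N_c, \qquad N_1 \ge N_2 \ge \dots \ge N_C, \\
\gamma &= \frac{N_1}{N_C}, \qquad \beta = \frac{N}{N + M}.
\end{align*}
```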
Our proposed method, TRAS, consists of a shared feature extractor and two classifiers, providing predictions for the teacher model and student model . There are two key ingredients to TRAS: (1) Learn through imitation, in which the student model imitates the adjusted output of the teacher model, and (2) transfer via sharing weights. In the following, we present technical details of these two ingredients.
3.2 Ingredient #1: Learn through Imitation
Given labeled data, a typical approach is to train a classifier by optimizing the softmax cross-entropy:
In LTSSL, however, the distribution of labeled data is heavily class-imbalanced, such that the learned classifier would be biased towards the majority class. To improve the training of the minority class, we propose to use the distribution-aware cross-entropy loss:
where is the estimate of class prior and is a scaling parameter. By minimizing , it encourages large margins between the true label and other negative labels. Using distribution-aware cross-entropy is not a new idea in the literature of long-tailed learning, such as Logit Adjustment (Menon et al, 2020) and Balanced Softmax (Ren et al, 2020). Interestingly, existing methods show that the scaling parameter plays an important role in model training, but it is usually used as a constant, e.g., . In the following, we show a new instance-dependent logit scaling method.
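As a minimal sketch of a distribution-aware cross-entropy, the snippet below shifts each logit by a scaled log-prior before the softmax, in the spirit of Logit Adjustment (Menon et al, 2020). The exact form used by TRAS is an assumption here, and the names `distribution_aware_ce`, `prior` and `tau` are illustrative.

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_z for v in logits]

def distribution_aware_ce(logits, label, prior, tau=1.0):
    """Cross-entropy on logits shifted by tau * log(prior).

    Assumed form: the shift makes rare classes harder to fit, which
    pushes the model towards larger margins for them. tau = 0 recovers
    the plain softmax cross-entropy.
    """
    adjusted = [z + tau * math.log(p) for z, p in zip(logits, prior)]
    return -log_softmax(adjusted)[label]

# Two classes with a 9:1 prior and uninformative logits.
prior = [0.9, 0.1]
logits = [0.0, 0.0]
loss_plain = distribution_aware_ce(logits, label=1, prior=prior, tau=0.0)
loss_adjusted = distribution_aware_ce(logits, label=1, prior=prior, tau=1.0)
print(loss_plain, loss_adjusted)
```

With equal raw logits, the adjusted loss for the minority label is larger than the plain one, so the classifier must produce a larger raw margin for that class to reach the same loss.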
In addition to a handful of labeled data, we have access to a large amount of unlabeled data to help improve generalization. In LTSSL, the underlying distribution of unlabeled data is also long-tailed, and conventional SSL methods have shown impaired performance on the minority class. This paper proposes to train the model using the pseudo-label distribution, rather than biased one-hot pseudo-labels. Intuitively, the label distribution offers more supervisory signals and can benefit minority-class training. We generate the pseudo-label distribution by first training a vanilla SSL model as the teacher, and then training a student model by imitating the output distribution of the teacher model. We opt for minimizing their Kullback-Leibler (KL) divergence:
are output probabilities of the teacher and student model respectively, which illustrate the implicit information of label distribution.
Note that the teacher model is trained via a conventional SSL algorithm and the produced pseudo-label distribution is still biased towards the majority class. To further enhance supervisory signals for the minority class, we present a new logit transformation to adjust the output of the teacher model. Specifically, for sample , we transform its pseudo-label distribution as follows:
where is the output logits and is the pseudo-label of . In this way, the pseudo-label distribution of unlabeled data is more balanced than the original distribution. We demonstrate the generated label distribution in Figure 3.
Notably, different from previous works that treat as a constant to scale the output logits, we use as a function of pseudo-labels. Concretely, given the pseudo-label , we define , where is a -dependent function, and are constants. This is because adjusting pseudo-label distribution to over-compensate the minority class can be harmful to the majority class. By employing the -dependent logit transformation function, we can alleviate this problem by flattening the label distribution of predicted minority-class samples more aggressively than other samples. In experiments, we simply set . Applying the proposed logit transformation generates a more balanced pseudo-label distribution to improve the training of the minority class as in Figure 3.
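The idea of a pseudo-label-dependent logit transformation can be sketched as follows. This is only an illustration under stated assumptions: the transformation is taken to subtract a scaled log-prior from the teacher logits, and `tau_fn` is a hypothetical stand-in for the scaling function, whose exact form and constants are not reproduced here.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def transform_teacher(logits, prior, tau_fn):
    """Return (pseudo_label, rebalanced distribution) for one sample.

    Assumed form: z_k - tau(y_hat) * log(prior_k), so classes with a
    small prior receive a relative boost in the resulting distribution.
    """
    y_hat = max(range(len(logits)), key=lambda k: logits[k])  # teacher pseudo-label
    tau = tau_fn(y_hat)
    adjusted = [z - tau * math.log(p) for z, p in zip(logits, prior)]
    return y_hat, softmax(adjusted)

def tau_fn(y_hat):
    # Hypothetical scaling rule: flatten more when the predicted class is the tail class.
    return 2.0 if y_hat == 2 else 1.0

prior = [0.7, 0.2, 0.1]   # head, torso, tail
logits = [2.0, 0.5, 0.0]  # teacher output biased towards the head class
p = softmax(logits)
y_hat, q = transform_teacher(logits, prior, tau_fn)
print(y_hat, p, q)
```

In this toy case the pseudo-label stays on the head class, while the transformed distribution assigns noticeably more mass to the tail class than the raw teacher output does.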
Putting together the objectives for labeled and unlabeled data, we minimize the loss function for TRAS as follows:
Here, is the indicator function, denotes the confidence threshold and we adopt the common setup for confident pseudo-labels from the student.
In this way, pseudo-label distribution can naturally describe the implicit information between labels. By applying the logit transformation, the distribution encodes more informative supervisory signals for the minority class. By imitating pseudo-label distribution, the student can alleviate data scarcity for minority classes.
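A minimal sketch of the imitation term on unlabeled data, assuming the mask keeps a sample when the student's maximum probability reaches the threshold. The default value 0.95 is the one commonly used with FixMatch and is an assumption here, as are the function names.

```python
import math

def kl_divergence(p, q, eps=1e-12):
    """KL(p || q) for two discrete distributions given as lists."""
    return sum(pi * math.log(pi / max(qi, eps)) for pi, qi in zip(p, q) if pi > 0)

def masked_imitation_loss(teacher_probs, student_probs, threshold=0.95):
    """Batch-averaged KL(teacher || student) over confident samples only.

    A sample contributes only when the student's top probability passes
    the threshold; masked samples count as zero in the average.
    """
    total = 0.0
    for t, s in zip(teacher_probs, student_probs):
        if max(s) >= threshold:
            total += kl_divergence(t, s)
    return total / len(teacher_probs)

# The teacher rows stand for the transformed (rebalanced) distributions.
teacher = [[0.5, 0.3, 0.2], [0.4, 0.3, 0.3]]
student = [[0.96, 0.03, 0.01], [0.5, 0.3, 0.2]]
loss = masked_imitation_loss(teacher, student)
print(loss)
```

Only the first sample passes the mask in this example, so the loss equals half of its KL term.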
3.3 Ingredient #2: Transfer via Sharing Weights
Learning through imitation of the teacher model can significantly compensate for the training of the minority class. However, it requires training two DNN models sequentially, which is costly in SSL.
To reduce the time consumption and simplify the approach, we propose to merge the training of teacher and student models into a single training procedure. In other words, the teacher and student share the feature extractor network. We further partition the parameter space into three disjoint subsets; (1) Let be a feature extractor for . (2) Let be a teacher classifier producing the prediction . (3) Similarly, let be a student classifier producing the prediction . Subsequently, let us define:
which is the output logits of the teacher model except that its gradient will not be calculated to update the teacher model’s classifier weights. Recall that function acts as a logit transformer of , we then consider:
as the joint objective. Because the teacher and student share a single feature extractor, TRAS only adds a linear classifier to the conventional SSL model, which incurs negligible training cost.
Let denote the loss for a conventional SSL method, the total loss function that TRAS optimizes is:
Particularly, if FixMatch is employed as the teacher model, consists of a cross-entropy loss on labeled data and a consistency regularization on unlabeled data. Specifically, we have:
where and are the output logits for weak and strong data augmentation, represents the pseudo-label for unlabeled data. In inference, we use the student classifier to predict the label.
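The consistency term of FixMatch on unlabeled data can be sketched as below: the pseudo-label comes from the weakly augmented view, and the strongly augmented view is trained towards it when the prediction is confident. The threshold 0.95 and the function names are assumptions for illustration.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(logits, label):
    """Softmax cross-entropy of one sample against a hard label."""
    return -math.log(softmax(logits)[label])

def fixmatch_unlabeled_loss(weak_logits, strong_logits, threshold=0.95):
    """Batch-averaged cross-entropy on strong views with confident weak-view pseudo-labels."""
    total = 0.0
    for z_weak, z_strong in zip(weak_logits, strong_logits):
        probs = softmax(z_weak)
        y_hat = max(range(len(probs)), key=lambda k: probs[k])  # pseudo-label
        if probs[y_hat] >= threshold:                           # confidence mask
            total += cross_entropy(z_strong, y_hat)
    return total / len(weak_logits)

weak = [[5.0, 0.0, 0.0], [1.0, 0.8, 0.5]]    # only the first sample is confident
strong = [[2.0, 1.0, 0.0], [0.0, 3.0, 0.0]]
consistency_loss = fixmatch_unlabeled_loss(weak, strong)
print(consistency_loss)
```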
3.4 Connection to Previous Work
The basic idea of TRAS is to transfer the label distribution from a vanilla teacher model to a student model that generalizes well on the minority class. The technique is related to knowledge distillation, which has been explored in some recent long-tailed learning works. For instance, LFME (Xiang et al, 2020) proposes to train the student model by distilling multiple teachers trained on less imbalanced datasets. DiVE (He et al, 2021) shows that flattening the output distribution of the teacher model using a constant temperature parameter can help the learning of minority classes. CBD (Iscen et al, 2021) distills features from the teacher to the student and shows that this can improve the learned representation of the minority class. Last but not least, xERM (Zhu et al, 2022) obtains an unbiased model by properly adjusting the weights between the empirical loss and the knowledge distillation loss.
In contrast to previous works that aim to solve supervised long-tailed learning, this paper studies semi-supervised long-tailed learning, where the amount of labeled data is much more limited. Moreover, previous works need to train teacher models via well-established long-tailed learning methods. However, our method TRAS only needs a vanilla SSL model as a teacher. Additionally, these methods have multiple-stage training procedures, but our method is simpler and can be trained in an end-to-end way.
4 Experiments
We conduct experiments on long-tailed versions of CIFAR-10, CIFAR-100, and SVHN, in comparison with state-of-the-art LTSSL methods. We then perform hyper-parameter sensitivity studies and ablation studies to better understand our proposed TRAS.
4.1 Experimental Setup
Datasets. We conduct experiments on common datasets long-tailed CIFAR-10 (CIFAR-10-LT), long-tailed CIFAR-100 (CIFAR-100-LT) and long-tailed SVHN (SVHN-LT) to evaluate our method. Without loss of generality, for imbalanced SSL settings, we randomly resample the datasets to meet the assumption that the distribution of labeled and unlabeled samples is consistent. We set the ratio of the class imbalance as () and the number of labeled data points of class as , where and for the unlabeled. Specifically, we set , for CIFAR-10-LT and SVHN-LT, , for CIFAR-100-LT respectively.
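One common way to build such long-tailed splits is an exponential class-size profile, in which the head class keeps all its samples and the tail class keeps the head size divided by the imbalance ratio. The sketch below assumes that profile; the function name and the example numbers are illustrative, not the settings used in the experiments.

```python
def long_tailed_counts(n_head, imbalance_ratio, num_classes):
    """Per-class sample counts following an exponential decay profile.

    Assumed profile: the k-th class (k = 0 is the head) keeps
    n_head * imbalance_ratio ** (-k / (num_classes - 1)) samples.
    """
    return [
        round(n_head * imbalance_ratio ** (-(k / (num_classes - 1))))
        for k in range(num_classes)
    ]

# Illustrative numbers only: 10 classes, 1000 head samples, ratio 100.
counts = long_tailed_counts(n_head=1000, imbalance_ratio=100, num_classes=10)
print(counts)
```

The same profile can be applied separately to the labeled and the unlabeled pool so that both share one class distribution.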
Following previous work (Lee et al, 2021), we evaluate the classification performance with imbalance ratios of 100 and 150 for CIFAR-10-LT and SVHN-LT, and 20 and 30 for CIFAR-100-LT. The ratio of labeled data is 10%, 20% and 30% for CIFAR-10-LT and SVHN-LT, and 20%, 40% and 50% for CIFAR-100-LT. Since the test set remains balanced, overall accuracy, minority-class accuracy, and Geometric Mean scores (GM) (Branco et al, 2016) of class-wise sensitivity are the three main metrics used to validate the proposed method.
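As a small illustration of the metrics, the sketch below computes class-wise sensitivity (recall) and its geometric mean from true and predicted labels. The helper names are illustrative; minority-class accuracy would be the mean of the recalls over whichever classes are designated as the minority.

```python
import math

def class_recalls(y_true, y_pred, num_classes):
    """Per-class recall (sensitivity) from integer label lists."""
    correct = [0] * num_classes
    total = [0] * num_classes
    for t, p in zip(y_true, y_pred):
        total[t] += 1
        correct[t] += int(t == p)
    return [c / n if n else 0.0 for c, n in zip(correct, total)]

def geometric_mean(values):
    """Geometric mean; zero as soon as any class has zero recall."""
    if any(v == 0 for v in values):
        return 0.0
    return math.exp(sum(math.log(v) for v in values) / len(values))

y_true = [0, 0, 1, 1, 2, 2, 2, 2]
y_pred = [0, 0, 1, 0, 2, 2, 2, 1]
recalls = class_recalls(y_true, y_pred, num_classes=3)
gm = geometric_mean(recalls)
print(recalls, gm)
```

Because the geometric mean collapses when any single class is poorly recognized, it rewards balanced performance across classes more strongly than overall accuracy does.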
Setup. We implement our method with FixMatch on a Wide ResNet-28-2 backbone (Zagoruyko and Komodakis, 2016). Our method is compared with the supervised baseline, semi-supervised learning methods, long-tailed supervised learning methods, and long-tailed semi-supervised learning methods, denoted by (a) Vanilla; (b) VAT (Miyato et al, 2018) and FixMatch (Sohn et al, 2020); (c) BALMS (Ren et al, 2020), classifier Re-Training (cRT) (Kang et al, 2020); (d) DARP (Kim et al, 2020), CReST (Wei et al, 2021a), ABC (Lee et al, 2021). We set the hyper-parameters by following FixMatch and train the neural networks for epochs with mini-batches in each epoch, with the batch size of , using Adam optimizer (Kingma and Ba, 2015). The learning rate is with a decay rate of . We start optimizing TRAS after training FixMatch for epochs. For all experiments, we report the mean and standard deviation of test accuracy over multiple runs.
4.2 Experimental Results
First, the performance of the compared algorithms under the main setting is reported in Table 1. Results of related methods are borrowed from ABC (Lee et al, 2021). It can be seen that our method achieves the best performance, and the improvement on the minority class is substantial. Conventional SSL methods such as VAT and FixMatch are known to perform unsatisfactorily on the minority class because the pseudo-labels of unlabeled data are affected by the biased model, which hinders the learning of minority classes. Our method significantly improves the performance on the minority class by exploiting knowledge transfer to generate a balanced label distribution, which conveys more implicit information than the one-hot pseudo-labels used in most previous LTSSL works. Moreover, our standard deviation is lower than that of other LTSSL methods, showing the superior stability of TRAS.
To further validate the effectiveness of our method, we report the performance on various settings. The results on CIFAR-10-LT, SVHN-LT and CIFAR-100-LT are reported in Table 2, Table 3 and Table 4. Our method TRAS outperforms other methods in all cases with respect to both overall accuracy and minority-class accuracy. In particular, TRAS achieves improvements of about 10%, 5%, and 7% on the minority class on the three datasets, respectively. Moreover, TRAS is more robust to class imbalance. As the imbalance ratio increases, the performance of existing methods deteriorates severely, while the accuracy of our method drops only slightly.
To evaluate whether our method TRAS performs balanced prediction for all classes, we measure its performance using Geometric Mean scores (GM) of class-wise accuracy. The results in Table 5 demonstrate that the proposed algorithm yields the best and most balanced performance in all classes. Additionally, TRAS achieves more significant performance improvement on the large dataset (CIFAR-100-LT).
4.3 How Does Pseudo-Label Distribution Impact the Performance?
Recall that we use hyper-parameters and to control the distribution of pseudo-labels by , we now analyze their influence on the performance. The results are reported in Figure 4. We find that impacts the performance much larger than , which coincides with our intuition because . It achieves comparable results by setting . But unlike , yields better performance than other values in our experiments. When setting and , test accuracy severely deteriorates because of the heavy bias towards minority classes in the pseudo-label distribution.
Interestingly, we find the balancedness of pseudo-label distribution matters much more than the accuracy of pseudo-labels in transfer. As increases, the top-5 accuracy of pseudo-labels is impaired while the overall accuracy remains competitive. This indicates that class imbalance hurts the performance more than inaccurate pseudo-labels in our approach.
To better understand this phenomenon, we investigate the impact of logit scaling parameters and on the quality of pseudo-labels for head, torso, and tail classes separately. As illustrated in Figure 5, reveals superb performance with high precision. However, in Figure 6, it shows the worst recall in the tail class. Since means that pseudo-labels are from the conventional SSL model which is biased to the head class, transferring their distribution to a target model does not help the training of tail classes, as shown in Figure 7.
Instead, by setting , it achieves the best performance in overall and tail-class accuracy as reported in Figure 3(a) and Figure 7. Notably, it produces high recall yet low precision for tail classes in Figure 6 and Figure 5. This observation confirms our suspicion that the balancedness of pseudo-label distribution requires more attention than the accuracy of pseudo-labels in knowledge transfer.
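The precision and recall of pseudo-labels discussed above can be computed per class as in the following sketch, which assumes the ground-truth labels of the unlabeled pool are available for analysis only. The toy data mimics a teacher biased towards the head class.

```python
def pseudo_label_precision_recall(y_true, y_pseudo, num_classes):
    """Per-class precision and recall of pseudo-labels against ground truth."""
    true_pos = [0] * num_classes
    predicted = [0] * num_classes
    actual = [0] * num_classes
    for t, p in zip(y_true, y_pseudo):
        actual[t] += 1
        predicted[p] += 1
        true_pos[t] += int(t == p)
    precision = [tp / n if n else 0.0 for tp, n in zip(true_pos, predicted)]
    recall = [tp / n if n else 0.0 for tp, n in zip(true_pos, actual)]
    return precision, recall

# Class 0 is the head class, class 1 the tail class.
y_true = [0, 0, 0, 0, 1, 1]
y_pseudo = [0, 0, 0, 0, 0, 1]  # one tail sample is absorbed by the head class
precision, recall = pseudo_label_precision_recall(y_true, y_pseudo, num_classes=2)
print(precision, recall)
```

In this toy case the tail class has perfect precision but only half of its samples are recovered, which is the high-precision, low-recall pattern described for a head-biased teacher.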
4.4 Better Understanding of TRAS
We analyze TRAS from the representation and classification perspectives on CIFAR-10-LT under the main setting. First, we compare the representations learned by ABC and our TRAS via t-distributed stochastic neighbor embedding (t-SNE) (Van der Maaten and Hinton, 2008) in Figure 8. It can be seen that TRAS yields clearer classification boundaries than ABC, which suggests that TRAS separates classes better thanks to improved representation learning.
Further, to analyze the classification results, we compare the confusion matrices of the predictions on the test set in Figure 9. Each row represents the ground-truth label and each column represents the prediction by ABC or our TRAS. The value in the i-th row and j-th column is the percentage of samples from the i-th class that are predicted as the j-th class. From the results, we can see that our TRAS performs better than ABC on the minority class. Moreover, we observe that TRAS might misclassify some majority-class samples as minority-class ones.
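A row-normalized confusion matrix of the kind described here can be computed as in this short sketch; the function name and the toy labels are illustrative.

```python
def confusion_matrix_percent(y_true, y_pred, num_classes):
    """Confusion matrix whose rows (ground truth) each sum to 100 percent."""
    counts = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(y_true, y_pred):
        counts[t][p] += 1
    matrix = []
    for row in counts:
        n = sum(row)
        matrix.append([100.0 * v / n if n else 0.0 for v in row])
    return matrix

y_true = [0, 0, 0, 0, 1, 1]
y_pred = [0, 0, 0, 1, 1, 0]
matrix = confusion_matrix_percent(y_true, y_pred, num_classes=2)
print(matrix)
```

Off-diagonal entries in a majority-class row show how often majority-class samples are assigned to other classes, which is the effect noted above for TRAS.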
Ablation studies. We conduct ablation studies on important parts of our approach under the main setting.
First, this paper applies logit transformation with to the teacher model’s prediction on unlabeled data for better performance of the teacher model. By removing logit transformation, the overall accuracy and minority-class accuracy under the main setting turn out to be 83.66% (-0.64%) and 77.54% (-4.66%) on CIFAR-10-LT, respectively.
Second, we modify the distribution-aware cross-entropy for the labeled data to the common cross-entropy loss, leading to 83.41% (-0.89%) and 80.28% (-1.92%) of the overall accuracy and minority-class accuracy. The marginal decline of the performance verifies the effectiveness of the learning through imitation approach.
Finally, we remove the sample mask on unlabeled data of the student model, which means all unlabeled data is used to imitate the teacher. The experiment shows that removing the sample mask decreases the performance slightly, i.e., 83.33% (-0.97%) and 79.60% (-2.59%) for overall and minority-class accuracy respectively. This demonstrates the advantage of selecting more accurate pseudo-labels for the student model.
4.5 Comparison with Two-stage Training
We further compare TRAS with two-stage training under the main setting in Table 6. In the two-stage training, we first train a FixMatch model as the teacher and then guide a student model by the teacher with our TRAS. We find that not only can our TRAS save training cost, but it is also better trained than the two-stage approach. This agrees with our expectation that, in two-stage training, the student lays more emphasis on the minority class, while fitting the pseudo-label distribution does not necessarily improve feature learning. Fortunately, two branches in TRAS share the feature learning backbone, which can improve the backbone and classifiers simultaneously.
To further show that double branches of TRAS can both improve feature learning, we stop propagating the gradients of the student branch from affecting the feature learning backbone. In this way, only the teacher model trains the feature extractor network, which is similar to the two-stage training. From the result, we see TRAS performs better, showing that the student can further enhance the feature learning by sharing the backbone with the teacher model.
| TRAS− (the student does not propagate its gradients to update the feature extractor) | 83.2/80.8 | 92.6/92.1 | 57.8/50.3 |
5 Conclusion
We introduce TRAS, a new method for LTSSL. TRAS (1) learns from a more class-balanced label distribution to improve minority-class generalization and (2) partitions the parameter space, enabling transfer via weight sharing of the transformed knowledge learned by the conventional SSL model. Extensive experiments on the CIFAR-10-LT, SVHN-LT, and CIFAR-100-LT datasets show that TRAS outperforms state-of-the-art methods on the minority class by a large margin. In future work, it would be interesting to extend TRAS to more established SSL methods.
Amodei D, Ananthanarayanan S, Anubhai R, et al (2016) Deep speech 2: End-to-end speech recognition in English and Mandarin. In: International Conference on Machine Learning, pp 173–182
Li Y, Wang H, Wei T, et al (2019) Towards automated semi-supervised learning. In: Proceedings of the AAAI Conference on Artificial Intelligence, pp 4237–4244
Tarvainen A, Valpola H (2017) Mean teachers are better role models: Weight-averaged consistency targets improve semi-supervised deep learning results. Advances in Neural Information Processing Systems 30:1195–1204