
Transfer and Share: Semi-Supervised Learning from Long-Tailed Data

05/26/2022
by Tong Wei, et al.
Nanjing University

Long-Tailed Semi-Supervised Learning (LTSSL) aims to learn from class-imbalanced data where only a few samples are annotated. Existing solutions typically either incur substantial cost to solve complex optimization problems or rely on class-balanced undersampling, which can result in information loss. In this paper, we present TRAS (TRAnsfer and Share) to effectively utilize long-tailed semi-supervised data. TRAS transforms the imbalanced pseudo-label distribution of a traditional SSL model via a carefully designed function to enhance the supervisory signals for minority classes. It then transfers the distribution to a target model so that the minority class receives significant attention. Interestingly, TRAS shows that a more balanced pseudo-label distribution can substantially benefit minority-class training, instead of seeking to generate accurate pseudo-labels as in previous works. To simplify the approach, TRAS merges the training of the traditional SSL model and the target model into a single procedure by sharing the feature extractor, where both classifiers help improve the representation learning. Extensive experiments show that TRAS delivers much higher accuracy than state-of-the-art methods on the entire set of classes as well as on the minority classes.


1 Introduction

Deep Neural Networks (DNNs) have been successfully used in many real-world applications (Amodei et al, 2016; He et al, 2016). However, training DNNs relies on large-scale, high-quality datasets, and obtaining such data has become a core problem in practice. First, large-scale data annotation is highly expensive, so only a small number of labels are available (Li et al, 2019; van Engelen and Hoos, 2020). Second, collected data usually follow a long-tailed distribution (He and Garcia, 2009; Liu et al, 2019; Wei and Li, 2019; Wei et al, 2021b, 2022), where only some classes (the majority class) have sufficient training samples while the other classes (the minority class) have only a few samples, as shown in Figure 1(a).

To utilize unlabeled data, Semi-Supervised Learning (SSL) has emerged as an attractive solution (Miyato et al, 2018; Tarvainen and Valpola, 2017; Berthelot et al, 2019b; Sohn et al, 2020; Berthelot et al, 2019a; Zhou et al, 2021; Guo et al, 2020). SSL methods make assumptions about the data distribution and build a learner that exploits unlabeled samples by selecting confident pseudo-labels. However, existing SSL methods have been shown to produce pseudo-labels that are biased towards the majority class (Kim et al, 2020), leading to undesirable performance.

Recently, Long-Tailed Semi-Supervised Learning (LTSSL) has been proposed to improve the performance of SSL models on long-tailed data. The main ideas of existing LTSSL methods (Kim et al, 2020; Wei et al, 2021a; Lee et al, 2021) are two-fold. One is to improve the quality of pseudo-labels from the SSL perspective. The other is to employ class-balanced sampling or post-hoc classifier adjustment to alleviate class imbalance from the long-tail perspective. These methods do improve conventional SSL models; however, the improvements come at the cost of high computational overhead or of information lost through undersampling.

How to utilize all data efficiently and effectively is the core challenge of LTSSL and the focus of this paper. To this end, we propose a new method called TRAS (TRAnsfer and Share), which has two key ingredients. Figure 1(b) shows the effectiveness of TRAS on the minority class.

Figure 1: (a) Long-tailed distribution of the training set under the main setting of CIFAR-10-LT. (b) Minority-class accuracy (%) on CIFAR-10-LT under class imbalance ratios 50, 100, and 150 with 20% of labels available. The proposed TRAS substantially improves the minority-class accuracy.

First, we compensate for minority-class training by generating a more balanced pseudo-label distribution. Guided by the pseudo-label distribution, DNNs can mine the relationships between classes and obtain richer information for minority classes. Learning from label distributions has been explored in previous literature, such as label distribution learning (Geng, 2016; Gao et al, 2017; Wang and Geng, 2019) and knowledge distillation (Xiang et al, 2020; He et al, 2021; Iscen et al, 2021), but it is still underexplored in LTSSL. In knowledge distillation, the label distribution is commonly generated by a well-trained teacher model. Such a teacher is not always available in SSL because of the limited, long-tailed labeled data and the high computational overhead. Instead, we employ as the teacher a conventional SSL model, which typically has low accuracy on the minority class. After our proposed logit transformation is applied, this conventional SSL model is able to guide the learning of the student model. The transformation is designed specifically to enhance the minority-class supervisory signals without introducing extra computational cost. By training the student model to imitate the enhanced supervisory signals, the minority class then receives significant attention.

Second, we propose to merge the training of the teacher and student models into a single procedure to reduce the computational cost. To this end, we use a double-branch neural network with a shared feature extractor and two classifiers that produce the predictions of the teacher and the student. The network is then trained end-to-end with a joint objective over the two classifiers. Besides reducing the training cost and simplifying the approach, we empirically find that both classifiers help improve the representation learning and learn clear classification boundaries between classes.

Our main contributions are summarized as follows:

  1. A new LTSSL method TRAS is proposed, which significantly improves the minority-class training without introducing extra training cost.

  2. TRAS transfers pseudo-label distribution from a vanilla SSL network (teacher) to another network (student) via a new logit transformation, instead of trying hard to construct a sophisticated LTSSL teacher model.

  3. TRAS reveals the importance of the balancedness of pseudo-label distribution in transfer for LTSSL.

  4. TRAS merges the training of teacher and student models by sharing the feature extractor, which simplifies the training procedure and benefits the representation learning.

  5. TRAS achieves state-of-the-art performance in various experiments. In particular, it improves minority-class accuracy by about 7%.

2 Related Work

Semi-supervised learning. Existing SSL methods aim to use unlabeled data to improve generalization. For this purpose, consistency regularization and entropy minimization have become the most frequently used techniques and yield considerable performance improvements. Specifically, Mean-Teacher (Tarvainen and Valpola, 2017) imposes consistency regularization between the predictions of the current model and of a self-ensembled model obtained via an exponential moving average. Virtual Adversarial Training (VAT) (Miyato et al, 2018) encourages the model to minimize the discrepancy between its predictions on unlabeled data before and after applying an adversarial perturbation. MixMatch (Berthelot et al, 2019b) minimizes the entropy of model predictions by sharpening the pseudo-label distribution. ReMixMatch (Berthelot et al, 2019a) improves MixMatch by adding a distribution alignment regularizer and augmentation anchoring. FixMatch (Sohn et al, 2020) combines consistency regularization and entropy minimization by regularizing the predictions for weakly and strongly augmented unlabeled data. However, these methods assume that both labeled and unlabeled data are class-balanced, which leads to poor performance on the minority class when they are applied to long-tailed datasets.

Long-tailed semi-supervised learning. To deal with long-tailed datasets, several LTSSL methods have been proposed. In a nutshell, existing methods aim to select pseudo-labels that are not only confident but also more class-balanced, to improve generalization on minority classes. For instance, DARP (Kim et al, 2020) estimates the underlying class distribution of unlabeled data and uses it to regularize the distribution of pseudo-labels, which requires solving a convex optimization problem. CReST (Wei et al, 2021a) uses class-aware confidence thresholds to select more pseudo-labels for the minority class. More recently, ABC (Lee et al, 2021) builds an auxiliary balanced classifier on top of a conventional SSL model via class-balanced undersampling. However, these approaches suffer from either high computational cost or loss of supervisory information. In this work, we propose a new algorithm, TRAS, which fully utilizes both labeled and unlabeled data through efficient pseudo-label distribution transfer and greatly improves the performance on the minority class.

3 Method: TRAS

We now introduce the problem setting in Section 3.1 and develop our proposed method TRAS, which consists of two key ingredients described in Section 3.2 and Section 3.3. Figure 2 shows the framework of the proposed TRAS.

3.1 Problem Setting

Let $\mathcal{D}_L = \{(x_i, y_i)\}_{i=1}^{N}$ be a labeled dataset, where $x_i$ denotes the $i$-th training example and $y_i \in \{1, \dots, C\}$ is the corresponding label. We introduce an unlabeled dataset $\mathcal{D}_U = \{u_j\}_{j=1}^{M}$, where $u_j$ is the $j$-th unlabeled data point. Following ABC (Lee et al, 2021), we assume that the class distributions of $\mathcal{D}_L$ and $\mathcal{D}_U$ are identical. We denote the number of labeled data points of class $c$ as $N_c$ (notice that $\sum_{c=1}^{C} N_c = N$), assuming that all classes are sorted by cardinality in descending order, $N_1 \ge N_2 \ge \dots \ge N_C$. Corresponding to LTSSL, we define the ratio of labeled data as $\beta = N/(N+M)$ and the class imbalance ratio as $\gamma = N_1/N_C$. Following previous LTSSL works, we divide the class space into the majority class and the minority class according to class frequencies in the training data. Our goal is to learn a model that generalizes well on both the majority class and the minority class.

Our proposed method, TRAS, consists of a shared feature extractor and two classifiers, which provide the predictions of the teacher model $f^{T}$ and the student model $f^{S}$. TRAS has two key ingredients: (1) learning through imitation, in which the student model imitates the adjusted output of the teacher model, and (2) transfer via weight sharing. In the following, we present the technical details of these two ingredients.

Figure 2: The TRAS method in diagrammatic form.

3.2 Ingredient #1: Learn through Imitation

Given labeled data, a typical approach is to train a classifier by optimizing the softmax cross-entropy:

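$$\mathcal{L}_{\mathrm{CE}}\big(f(x), y\big) = -\log \frac{\exp\big(f_y(x)\big)}{\sum_{y'=1}^{C} \exp\big(f_{y'}(x)\big)} \qquad (1)$$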

In LTSSL, however, the distribution of labeled data is heavily class-imbalanced, such that the learned classifier would be biased towards the majority class. To improve the training of the minority class, we propose to use the distribution-aware cross-entropy loss:

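$$\mathcal{L}_{\mathrm{DA}}\big(f(x), y\big) = -\log \frac{\exp\big(f_y(x) + \tau \log \pi_y\big)}{\sum_{y'=1}^{C} \exp\big(f_{y'}(x) + \tau \log \pi_{y'}\big)} \qquad (2)$$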

where $\pi_y$ is the estimate of the prior of class $y$ and $\tau$ is a scaling parameter. Minimizing $\mathcal{L}_{\mathrm{DA}}$ encourages large margins between the true label and the other (negative) labels. Using a distribution-aware cross-entropy is not a new idea in the long-tailed learning literature; examples include Logit Adjustment (Menon et al, 2020) and Balanced Softmax (Ren et al, 2020). Interestingly, existing methods show that the scaling parameter $\tau$ plays an important role in model training, but it is usually treated as a constant, e.g., . In the following, we present a new instance-dependent logit scaling method.
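As a rough illustration of Eqs. (1) and (2), the following pure-Python sketch computes both losses for a single example. It assumes the logit-adjusted form in which $\tau \log \pi_c$ is added to every logit before the softmax, as in Logit Adjustment; the function names and the example prior are hypothetical.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax of a single logit vector."""
    m = max(logits)
    log_norm = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_norm for v in logits]


def cross_entropy(logits: Sequence[float], y: int) -> float:
    """Softmax cross-entropy of Eq. (1) for one labeled example."""
    return -log_softmax(logits)[y]


def da_cross_entropy(logits: Sequence[float], y: int,
                     prior: Sequence[float], tau: float) -> float:
    """Distribution-aware cross-entropy (logit-adjusted form):
    add tau * log(pi_c) to each logit, then apply cross-entropy."""
    adjusted = [z + tau * math.log(p) for z, p in zip(logits, prior)]
    return cross_entropy(adjusted, y)


if __name__ == "__main__":
    prior = [0.6, 0.3, 0.1]   # hypothetical head-to-tail class prior
    logits = [0.0, 0.0, 0.0]  # an undecided classifier
    print(da_cross_entropy(logits, 0, prior, tau=1.0))  # head-class loss
    print(da_cross_entropy(logits, 2, prior, tau=1.0))  # tail-class loss
```

With uniform logits and τ = 1, the tail-class loss (−log 0.1 ≈ 2.30) is larger than the head-class loss (−log 0.6 ≈ 0.51), so fitting a rare-class example requires a larger logit margin; with τ = 0 the loss reduces to plain cross-entropy.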

In addition to a handful of labeled data, we can access a large amount of unlabeled data to help improve generalization. In LTSSL, the underlying distribution of unlabeled data is also long-tailed, and conventional SSL methods have shown impaired performance on the minority class. This paper proposes to train the model using a pseudo-label distribution rather than biased one-hot pseudo-labels. Intuitively, a label distribution offers more supervisory signals and can benefit minority-class training. We generate the pseudo-label distribution by first training a vanilla SSL model as the teacher, and then training a student model to imitate the output distribution of the teacher model. We opt for minimizing their Kullback-Leibler (KL) divergence:

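$$\mathcal{L}_{\mathrm{KL}}(u) = \mathrm{KL}\big(p^{T}(u) \,\|\, p^{S}(u)\big) = \sum_{c=1}^{C} p^{T}_{c}(u) \log \frac{p^{T}_{c}(u)}{p^{S}_{c}(u)} \qquad (3)$$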

where $p^{T}(u)$ and $p^{S}(u)$ are the output probabilities of the teacher and student models, respectively, which carry implicit information about the label distribution.
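A minimal sketch of the per-sample KL term in Eq. (3), assuming plain Python lists of probabilities; the `eps` guard is an implementation choice of this sketch, not part of the method.

```python
import math
from typing import Sequence


def kl_divergence(p_teacher: Sequence[float], p_student: Sequence[float],
                  eps: float = 1e-12) -> float:
    """KL(p_teacher || p_student) for a single unlabeled example.

    Classes with zero teacher probability contribute nothing; eps only
    guards against log(0) on the student side.
    """
    total = 0.0
    for pt, ps in zip(p_teacher, p_student):
        if pt > 0.0:
            total += pt * (math.log(pt) - math.log(max(ps, eps)))
    return total


def soft_cross_entropy(p_teacher: Sequence[float], p_student: Sequence[float],
                       eps: float = 1e-12) -> float:
    """-sum_c p^T_c log p^S_c; differs from the KL only by the teacher entropy."""
    return -sum(pt * math.log(max(ps, eps))
                for pt, ps in zip(p_teacher, p_student) if pt > 0.0)
```

Because KL(p^T || p^S) equals the soft-label cross-entropy minus the teacher's entropy, minimizing either one gives the student head the same gradients when the teacher distribution is treated as fixed.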

Note that the teacher model is trained via a conventional SSL algorithm, so the pseudo-label distribution it produces is still biased towards the majority class. To further enhance the supervisory signals for the minority class, we present a new logit transformation that adjusts the output of the teacher model. Specifically, for a sample $u$, we transform its pseudo-label distribution as follows:

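$$\tilde{p}^{T}(u) = \psi\big(z(u)\big) = \mathrm{softmax}\big(z(u) - \tau(\hat{y}) \log \pi\big) \qquad (4)$$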

where $z(u)$ denotes the output logits and $\hat{y}$ is the pseudo-label of $u$. In this way, the pseudo-label distribution of unlabeled data becomes more balanced than the original distribution. We illustrate the generated label distribution in Figure 3.

Notably, unlike previous works that treat $\tau$ as a constant to scale the output logits, we make $\tau$ a function of the pseudo-label. Concretely, given the pseudo-label $\hat{y}$, we define $\tau(\hat{y})$ through a $\hat{y}$-dependent function $g(\hat{y})$ and two constants $A$ and $B$. The reason is that adjusting the pseudo-label distribution to over-compensate the minority class can harm the majority class. With the $\hat{y}$-dependent logit transformation, we can alleviate this problem by flattening the label distribution of samples predicted as the minority class more aggressively than that of other samples. In experiments, we simply set . Applying the proposed logit transformation generates a more balanced pseudo-label distribution, which improves the training of the minority class, as shown in Figure 3.
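The sketch below illustrates a pseudo-label-dependent transformation under explicit assumptions: it subtracts $\tau(\hat{y}) \log \pi_c$ from each teacher logit before the softmax, and it uses a made-up $\tau(\hat{y}) = A \cdot g(\hat{y}) + B$ with $g(\hat{y}) = 1/(C\,\pi_{\hat{y}})$. The linear form, $g$, $A$, and $B$ here are hypothetical choices for illustration, not the settings used in the experiments.

```python
import math
from typing import Callable, List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def transform_teacher(logits: Sequence[float], prior: Sequence[float],
                      tau_fn: Callable[[int], float]) -> List[float]:
    """Shift probability mass of a teacher prediction toward rare classes.

    The pseudo-label y_hat is the teacher's arg-max class, and the scaling
    tau_fn(y_hat) depends on it, so samples predicted as rare classes can be
    adjusted more strongly than others.
    """
    y_hat = max(range(len(logits)), key=lambda c: logits[c])
    tau = tau_fn(y_hat)
    return softmax([z - tau * math.log(p) for z, p in zip(logits, prior)])


# Hypothetical instantiation: tau(y) = A * g(y) + B with g(y) = 1 / (C * pi_y),
# which grows for rarer predicted classes. A, B, and g are illustrative only.
PRIOR = [0.6, 0.3, 0.1]
A, B = 1.0, 0.5


def tau_example(y_hat: int) -> float:
    return A / (len(PRIOR) * PRIOR[y_hat]) + B


if __name__ == "__main__":
    teacher_logits = [2.0, 1.0, 0.5]
    print(softmax(teacher_logits))                                # original
    print(transform_teacher(teacher_logits, PRIOR, tau_example))  # rebalanced
```

Setting `tau_fn` to return 0 recovers the teacher's original softmax, and because `tau_example` grows for rarer predicted classes, samples pseudo-labeled as tail classes are adjusted most strongly.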

Figure 3: Comparison of ground-truth label distribution and our generated pseudo-label distribution on CIFAR-100-LT dataset under class imbalance ratio 20 with 40% of labels available.

Putting together the objectives for labeled and unlabeled data, we minimize the following loss function for TRAS:

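$$\mathcal{L}_{\mathrm{TRAS}} = \sum_{(x, y) \in \mathcal{D}_L} \mathcal{L}_{\mathrm{DA}}\big(f^{S}(x), y\big) + \sum_{u \in \mathcal{D}_U} \mathbb{1}\Big(\max_{c} p^{S}_{c}(u) \ge t\Big)\, \mathrm{KL}\big(\tilde{p}^{T}(u) \,\|\, p^{S}(u)\big) \qquad (5)$$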

Here, $\mathbb{1}(\cdot)$ is the indicator function and $t$ denotes the confidence threshold; we adopt the common setup for $t$ to select confident pseudo-labels from the student.
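A minimal batch-level sketch of the unlabeled term in Eq. (5), assuming the teacher distributions have already been transformed and that, as stated above, the confidence mask is computed from the student's maximum probability. The default threshold 0.95 is only an illustrative value (it is the one commonly used with FixMatch), and normalizing by the full batch size is an assumption of this sketch.

```python
import math
from typing import Sequence


def masked_imitation_loss(teacher_probs: Sequence[Sequence[float]],
                          student_probs: Sequence[Sequence[float]],
                          threshold: float = 0.95,
                          eps: float = 1e-12) -> float:
    """Batch mean of 1(max_c p^S_c(u) >= t) * KL(p~^T(u) || p^S(u)).

    teacher_probs: already-transformed teacher distributions, one per sample.
    student_probs: student distributions for the same samples.
    Masked-out samples still count in the denominator (FixMatch-style mean).
    """
    if not teacher_probs:
        return 0.0
    total = 0.0
    for pt, ps in zip(teacher_probs, student_probs):
        if max(ps) >= threshold:  # keep only confident student predictions
            total += sum(a * (math.log(a) - math.log(max(b, eps)))
                         for a, b in zip(pt, ps) if a > 0.0)
    return total / len(teacher_probs)
```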

In this way, the pseudo-label distribution naturally describes the implicit information between labels. With the logit transformation, the distribution encodes more informative supervisory signals for the minority class. By imitating this pseudo-label distribution, the student can alleviate the data scarcity of minority classes.

3.3 Ingredient #2: Transfer via Sharing Weights

Learning through imitation of the teacher model can significantly compensate for the training of the minority class. However, it requires training two DNN models sequentially, which is costly in SSL.

To reduce the time consumption and simplify the approach, we propose to merge the training of the teacher and student models into a single training procedure. In other words, the teacher and student share the feature extractor network. We further partition the parameter space into three disjoint subsets: (1) let $\phi(\cdot)$ be a feature extractor applied to the input $x$; (2) let $h^{T}(\cdot)$ be a teacher classifier producing the prediction $z^{T}(x) = h^{T}(\phi(x))$; (3) similarly, let $h^{S}(\cdot)$ be a student classifier producing the prediction $z^{S}(x) = h^{S}(\phi(x))$. Subsequently, we define:

$$\bar{z}^{T}(u) = \bar{h}^{T}\big(\phi(u)\big), \qquad (6)$$

which is the output logits of the teacher model, except that its gradient is not computed to update the weights of the teacher classifier. Recalling that the function $\psi$ acts as a logit transformer of $\bar{z}^{T}(u)$, we then consider:

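$$\mathcal{L}_{\mathrm{TRAS}} = \sum_{(x, y) \in \mathcal{D}_L} \mathcal{L}_{\mathrm{DA}}\big(z^{S}(x), y\big) + \sum_{u \in \mathcal{D}_U} \mathbb{1}\Big(\max_{c} \mathrm{softmax}_c\big(z^{S}(u)\big) \ge t\Big)\, \mathrm{KL}\Big(\psi\big(\bar{z}^{T}(u)\big) \,\Big\|\, \mathrm{softmax}\big(z^{S}(u)\big)\Big) \qquad (7)$$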

as the joint objective. Since the teacher and student share a single feature extractor, TRAS only adds a linear classifier to the conventional SSL model, which incurs negligible training cost.
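To make the double-branch structure concrete, here is a toy, framework-free sketch: one shared feature extractor feeds a teacher head and a student head, and prediction uses the student head only. The class name, the single ReLU layer, and the weights are hypothetical; a real implementation would use a WRN-28-2 backbone with an autodiff framework, where the stop-gradient on the teacher classifier in Eq. (6) would be applied, which this sketch does not model.

```python
from dataclasses import dataclass
from typing import List, Sequence, Tuple

Matrix = List[List[float]]


def matvec(w: Matrix, x: Sequence[float]) -> List[float]:
    """Dense matrix-vector product."""
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def relu(v: Sequence[float]) -> List[float]:
    return [max(0.0, a) for a in v]


@dataclass
class DoubleBranchNet:
    """Toy double-branch model: one shared feature extractor phi and two
    linear heads (hypothetical stand-in for the real network)."""
    w_phi: Matrix      # shared feature extractor (one ReLU layer here)
    w_teacher: Matrix  # teacher classifier h^T
    w_student: Matrix  # student classifier h^S

    def forward(self, x: Sequence[float]) -> Tuple[List[float], List[float]]:
        feat = relu(matvec(self.w_phi, x))  # computed once, shared by both heads
        return matvec(self.w_teacher, feat), matvec(self.w_student, feat)

    def predict(self, x: Sequence[float]) -> int:
        """Inference uses only the student head."""
        _, z_s = self.forward(x)
        return max(range(len(z_s)), key=lambda c: z_s[c])
```

The feature vector is computed once per input and reused by both heads, which is why adding the student branch costs little more than one extra linear layer.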

Let $\mathcal{L}_{\mathrm{SSL}}$ denote the loss of a conventional SSL method. The total loss function that TRAS optimizes is:

$$\mathcal{L} = \mathcal{L}_{\mathrm{SSL}} + \mathcal{L}_{\mathrm{TRAS}} \qquad (8)$$

In particular, if FixMatch is employed as the teacher model, $\mathcal{L}_{\mathrm{SSL}}$ consists of a cross-entropy loss on labeled data and a consistency regularization term on unlabeled data. Specifically, we have:

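$$\mathcal{L}_{\mathrm{SSL}} = \sum_{(x, y) \in \mathcal{D}_L} \mathcal{L}_{\mathrm{CE}}\big(z^{T}(x), y\big) + \sum_{u \in \mathcal{D}_U} \mathbb{1}\Big(\max_{c} \mathrm{softmax}_c\big(z^{T}_{w}(u)\big) \ge t\Big)\, \mathcal{L}_{\mathrm{CE}}\big(z^{T}_{s}(u), \hat{y}\big) \qquad (9)$$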

where $z^{T}_{w}(u)$ and $z^{T}_{s}(u)$ are the output logits for the weakly and strongly augmented versions of $u$, and $\hat{y}$ represents the pseudo-label of the unlabeled sample. At inference time, we use the student classifier to predict the label.
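The unlabeled part of Eq. (9) follows FixMatch (Sohn et al, 2020). The sketch below computes it for a batch of logits in plain Python; the function name, the batch-mean normalization, and the 0.95 default threshold (FixMatch's commonly used value) are assumptions for illustration.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax of a single logit vector."""
    m = max(logits)
    log_norm = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_norm for v in logits]


def fixmatch_unlabeled_loss(weak_logits: Sequence[Sequence[float]],
                            strong_logits: Sequence[Sequence[float]],
                            threshold: float = 0.95) -> float:
    """FixMatch-style consistency term for a batch of unlabeled samples.

    The pseudo-label comes from the weakly augmented view; cross-entropy is
    applied to the strongly augmented view only when the weak prediction is
    confident. The mean is taken over the whole batch, masked samples included.
    """
    if not weak_logits:
        return 0.0
    total = 0.0
    for zw, zs in zip(weak_logits, strong_logits):
        log_pw = log_softmax(zw)
        y_hat = max(range(len(log_pw)), key=lambda c: log_pw[c])
        if math.exp(log_pw[y_hat]) >= threshold:  # confidence mask
            total -= log_softmax(zs)[y_hat]       # CE on the strong view
    return total / len(weak_logits)
```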

3.4 Connection to Previous Work

The basic idea of TRAS is to transfer the label distribution from a vanilla teacher model to a student model that generalizes well on the minority class. This technique is related to knowledge distillation, which has been explored in several recent long-tailed learning works. For instance, LFME (Xiang et al, 2020) trains the student model by distilling multiple teachers trained on less imbalanced datasets. DiVE (He et al, 2021) shows that flattening the output distribution of the teacher model with a constant temperature parameter can help the learning of minority classes. CBD (Iscen et al, 2021) distills features from the teacher to the student and shows that this improves the learned representation of the minority class. Last but not least, xERM (Zhu et al, 2022) obtains an unbiased model by properly adjusting the weights between the empirical loss and the knowledge distillation loss.

In contrast to previous works that address supervised long-tailed learning, this paper studies semi-supervised long-tailed learning, where the amount of labeled data is much more limited. Moreover, previous works need to train teacher models with well-established long-tailed learning methods, whereas TRAS only needs a vanilla SSL model as the teacher. Additionally, these methods have multi-stage training procedures, whereas our method is simpler and can be trained end-to-end.

4 Experiments

We conduct experiments on long-tailed versions of CIFAR-10, CIFAR-100, and SVHN, in comparison with state-of-the-art LTSSL methods. We then perform hyper-parameter sensitivity studies and ablation studies to better understand the proposed TRAS.

4.1 Experimental Setup

Datasets. We evaluate our method on three common datasets: long-tailed CIFAR-10 (CIFAR-10-LT), long-tailed CIFAR-100 (CIFAR-100-LT), and long-tailed SVHN (SVHN-LT). Without loss of generality, for the imbalanced SSL settings, we randomly resample the datasets so that labeled and unlabeled samples follow the same class distribution. We set the ratio of the class imbalance as () and the number of labeled data points of class as , where and for the unlabeled. Specifically, we set , for CIFAR-10-LT and SVHN-LT, , for CIFAR-100-LT respectively.
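One way to produce per-class counts for such long-tailed splits, and to check the resulting $\gamma$ and $\beta$, is sketched below. It assumes an exponential decay of class sizes from class 1 to class $C$, a profile commonly used in the long-tailed literature; the function names and the sizes in the usage example are hypothetical and do not correspond to the settings used in the experiments.

```python
from typing import List


def long_tailed_counts(n_max: int, num_classes: int, gamma: float) -> List[int]:
    """Per-class counts decaying exponentially from n_max (class 1) down to
    n_max / gamma (class C). The exponential profile is an assumption."""
    return [round(n_max * gamma ** (-c / (num_classes - 1)))
            for c in range(num_classes)]


def imbalance_ratio(counts: List[int]) -> float:
    """gamma = N_1 / N_C for counts sorted in descending order."""
    return counts[0] / counts[-1]


def labeled_ratio(labeled: List[int], unlabeled: List[int]) -> float:
    """beta = N / (N + M)."""
    n, m = sum(labeled), sum(unlabeled)
    return n / (n + m)


if __name__ == "__main__":
    labeled = long_tailed_counts(100, 10, 10.0)    # hypothetical sizes
    unlabeled = long_tailed_counts(200, 10, 10.0)  # same class distribution
    print(labeled, imbalance_ratio(labeled), labeled_ratio(labeled, unlabeled))
```

Drawing the unlabeled counts from the same profile keeps the labeled and unlabeled class distributions identical, matching the assumption stated in Section 3.1.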

Following previous work (Lee et al, 2021), we evaluate the classification performance with imbalance ratio $\gamma$ = 100 and 150 for CIFAR-10-LT and SVHN-LT, and $\gamma$ = 20 and 30 for CIFAR-100-LT. The ratio of labeled data $\beta$ is 10%, 20%, and 30% for CIFAR-10-LT and SVHN-LT, and 20%, 40%, and 50% for CIFAR-100-LT. Since the test set remains balanced, we use three main metrics to validate the proposed method: overall accuracy, minority-class accuracy, and the Geometric Mean (GM) of class-wise sensitivity (Branco et al, 2016).
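The three metrics can be computed from predictions as in the following sketch. The set of minority classes is passed in explicitly because the majority/minority split is a configuration choice; the helper names are hypothetical, and assigning GM = 0 when any class has zero recall is a convention of this sketch.

```python
import math
from typing import List, Sequence, Set


def per_class_recall(y_true: Sequence[int], y_pred: Sequence[int],
                     num_classes: int) -> List[float]:
    """Class-wise sensitivity (recall); classes absent from y_true get 0.0."""
    hits = [0] * num_classes
    totals = [0] * num_classes
    for t, p in zip(y_true, y_pred):
        totals[t] += 1
        hits[t] += int(t == p)
    return [h / n if n else 0.0 for h, n in zip(hits, totals)]


def geometric_mean(recalls: Sequence[float]) -> float:
    """GM of class-wise recall; any class with zero recall gives GM = 0."""
    if any(r == 0.0 for r in recalls):
        return 0.0
    return math.exp(sum(math.log(r) for r in recalls) / len(recalls))


def minority_accuracy(recalls: Sequence[float], minority: Set[int]) -> float:
    """Mean recall over the given minority classes. On a balanced test set
    this equals the accuracy on minority-class test samples."""
    return sum(recalls[c] for c in minority) / len(minority)
```

Because GM multiplies class-wise recalls, a single poorly recognized class drags it down sharply, which is why it complements overall accuracy as a balance measure.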

Setup. We implement our method on top of FixMatch with a Wide ResNet-28-2 backbone (Zagoruyko and Komodakis, 2016). Our method is compared with the supervised baseline, long-tailed supervised learning methods, and long-tailed semi-supervised learning methods, denoted by (a) Vanilla; (b) VAT (Miyato et al, 2018) and FixMatch (Sohn et al, 2020); (c) BALMS (Ren et al, 2020) and classifier Re-Training (cRT) (Kang et al, 2020); (d) DARP (Kim et al, 2020), CReST (Wei et al, 2021a), and ABC (Lee et al, 2021). We set the hyper-parameters by following FixMatch and train the neural networks for epochs with mini-batches in each epoch, with the batch size of , using Adam optimizer (Kingma and Ba, 2015). The learning rate is with a decay rate of . We start optimizing TRAS after training FixMatch for epochs. For all experiments, we report the mean and standard deviation of test accuracy over multiple runs.

4.2 Experimental Results

First, the performance of the compared algorithms under the main setting is reported in Table 1. Results of related methods are borrowed from ABC (Lee et al, 2021). It can be seen that our method achieves the best performance, and the improvement on the minority class is especially large. Standard SSL methods such as VAT and FixMatch are known to perform unsatisfactorily on the minority class, because the pseudo-labels of unlabeled data are affected by the biased model, which hinders the learning of minority classes. Our method significantly improves the performance on the minority class by exploiting knowledge transfer to generate a balanced label distribution, which conveys more implicit information than the one-hot pseudo-labels used in most previous LTSSL works. Moreover, our standard deviation is lower than that of other LTSSL methods, showing the superior stability of TRAS.

Algorithm        CIFAR-10-LT   SVHN-LT     CIFAR-100-LT
Vanilla          55.3/33.9     77.0/63.3   40.1/25.2
VAT              55.3/28.2     81.3/68.2   40.4/24.8
BALMS            70.7/69.8     87.6/85.0   50.2/42.9
FixMatch         72.3/53.8     88.0/79.4   51.0/32.8
w/ CReST+PDA     76.6/61.4     89.1/81.7   51.6/36.4
w/ DARP          73.7/57.0     88.6/80.5   51.4/33.9
w/ DARP+cRT      78.1/66.6     89.9/83.5   54.7/41.2
w/ ABC           81.1/72.0     92.0/87.9   56.3/43.4
w/ TRAS (ours)   84.3/82.2     93.4/92.5   58.5/50.3
Table 1: Overall accuracy (%)/minority-class accuracy (%) under the main setting.

To further validate the effectiveness of our method, we report the performance under various settings. The results on CIFAR-10-LT, SVHN-LT, and CIFAR-100-LT are reported in Table 2, Table 3, and Table 4, respectively. TRAS outperforms the other methods in all cases in terms of both overall accuracy and minority-class accuracy. In particular, TRAS achieves improvements of about 10%, 5%, and 7% on the minority class on the three datasets. Moreover, TRAS is more robust to class imbalance: as the imbalance ratio increases, the performance of existing methods deteriorates severely, while the accuracy of our method drops only slightly.

CIFAR-10-LT
Algorithm
FixMatch 70.0/48.9 74.9/58.2 68.5/45.8
w/ CReST+PDA 73.9/58.9 77.6/64.0 70.0/49.4
w/ DARP+cRT 74.6/59.2 79.0/67.7 73.2/57.1
w/ ABC 77.2/65.7 81.5/72.9 77.1/64.4
w/ TRAS(ours) 82.1/78.6 85.0/83.0 81.7/77.2
Table 2: Overall accuracy(%)/minority-class accuracy(%) for CIFAR-10-LT. Two imbalance ratios and three labeled data ratios are evaluated.
SVHN-LT
Algorithm
FixMatch 88.5/80.3 88.7/80.7 85.6/74.6
w/ CReST+PDA 89.2/81.7 89.9/83.0 86.7/76.7
w/ DARP+cRT 89.3/83.9 90.7/84.8 88.0/80.1
w/ ABC 92.3/88.7 92.3/88.3 91.2/86.2
w/ TRAS(ours) 93.2/92.5 93.9/93.4 92.1/91.1
Table 3: Overall accuracy(%)/minority-class accuracy(%) on SVHN-LT. Two imbalance ratios and three labeled data ratios are evaluated.
CIFAR-100-LT
Algorithm
FixMatch 46.1/26.6 52.3/34.7 47.6/27.6
w/ CReST+PDA 46.7/29.3 52.7/37.4 48.5/30.0
w/ DARP+cRT 48.9/33.5 55.9/43.5 51.3/36.4
w/ ABC 49.7/34.6 58.3/46.7 53.6/38.8
w/ TRAS(ours) 51.6/41.8 60.3/53.5 55.5/46.5
Table 4: Overall accuracy(%)/minority-class accuracy(%) on CIFAR-100-LT. Two imbalance ratios and three labeled data ratios are evaluated.

To evaluate whether TRAS makes balanced predictions across all classes, we measure its performance using the Geometric Mean (GM) of class-wise accuracy. The results in Table 5 demonstrate that the proposed algorithm yields the best and most balanced performance across all classes. Additionally, TRAS achieves a more significant improvement on the larger dataset (CIFAR-100-LT).

Algorithm        CIFAR-10-LT   SVHN-LT   CIFAR-100-LT
FixMatch         62.0          87.3      38.5
w/ CReST+PDA     74.4          88.6      42.3
w/ DARP          71.5          87.6      40.4
w/ DARP+cRT      76.7          89.8      47.0
w/ ABC           80.5          91.8      49.0
w/ TRAS (ours)   81.9          93.4      54.0
Table 5: Results of GM(%) under the main setting.

4.3 How Does Pseudo-Label Distribution Impact the Performance?

Recall that we use the hyper-parameters $A$ and $B$ to control the pseudo-label distribution via $\tau(\hat{y})$; we now analyze their influence on the performance. The results are reported in Figure 4. We find that impacts the performance much more than , which coincides with our intuition because . It achieves comparable results by setting . But unlike , yields better performance than other values in our experiments. When setting and , test accuracy severely deteriorates because of the heavy bias towards minority classes in the pseudo-label distribution.

Figure 4: The impact of the values of A and B on CIFAR-10-LT under class imbalance ratio 100 with 20% of labels available. (a) Overall accuracy (%) of TRAS; (b) top-5 accuracy (%) of the teacher's pseudo-labels after transformation.

Interestingly, we find the balancedness of pseudo-label distribution matters much more than the accuracy of pseudo-labels in transfer. As increases, the top-5 accuracy of pseudo-labels is impaired while the overall accuracy remains competitive. This indicates that class imbalance hurts the performance more than inaccurate pseudo-labels in our approach.

To better understand this phenomenon, we investigate the impact of the logit scaling parameters $A$ and $B$ on the quality of pseudo-labels for head, torso, and tail classes separately. As illustrated in Figure 5, reveals superb performance with high precision. However, in Figure 6, it shows the worst recall on the tail class. Since means that pseudo-labels are from the conventional SSL model, which is biased towards the head class, transferring their distribution to a target model does not help the training of tail classes, as shown in Figure 7.

Instead, by setting , it achieves the best overall and tail-class accuracy, as reported in Figure 4(a) and Figure 7. Notably, this setting produces high recall yet low precision for tail classes in Figure 6 and Figure 5. This observation supports our hypothesis that the balancedness of the pseudo-label distribution deserves more attention than the accuracy of pseudo-labels in knowledge transfer.

Figure 5: Comparison of pseudo-label precision by varying the values of A and B on CIFAR-10-LT under class imbalance ratio 100 with 20% of labels available. The x-axis is the number of epochs, and the y-axis is the precision. Classes are divided into head ({0, 1, 2}), torso ({3, 4, 5, 6}) and tail ({7, 8, 9}).
Figure 6: Comparison of pseudo-label recall by varying the values of A and B on CIFAR-10-LT under class imbalance ratio 100 with 20% of labels. The x-axis is the number of epochs, and the y-axis is the recall.
Figure 7: Comparison of test accuracy by varying the values of A and B on CIFAR-10-LT under class imbalance ratio 100 with 20% of labels. The x-axis is the number of epochs, and the y-axis is the test accuracy.

4.4 Better Understanding of TRAS

We analyze TRAS from the representation and classification perspectives on CIFAR-10-LT under the main setting. First, we compare the representations learned by ABC and TRAS via t-distributed stochastic neighbor embedding (t-SNE) (Van der Maaten and Hinton, 2008) in Figure 8. TRAS produces clearer classification boundaries than ABC, which demonstrates that it better distinguishes between classes through improved representation learning.

Figure 8: Results of t-SNE for (a) ABC (Lee et al, 2021) and (b) our TRAS.

Further, to analyze the classification results, we compare the confusion matrices of the predictions on the test set in Figure 9. Each row represents the ground-truth label and each column represents the prediction by ABC or TRAS. The value in the $i$-th row and $j$-th column is the percentage of samples from the $i$-th class that are predicted as the $j$-th class. The results show that TRAS performs better than ABC on the minority class. Moreover, we observe that TRAS may misclassify some majority-class samples as minority-class ones.
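For reference, a row-normalized confusion matrix of the kind described above can be computed as follows; this is a generic sketch with a hypothetical function name, not the authors' plotting code.

```python
from typing import List, Sequence


def confusion_matrix_percent(y_true: Sequence[int], y_pred: Sequence[int],
                             num_classes: int) -> List[List[float]]:
    """Row-normalized confusion matrix: entry [i][j] is the percentage of
    class-i samples predicted as class j (rows sum to 100 when non-empty)."""
    counts = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(y_true, y_pred):
        counts[t][p] += 1
    matrix = []
    for row in counts:
        n = sum(row)
        matrix.append([100.0 * v / n if n else 0.0 for v in row])
    return matrix
```

Normalizing each row by its class size makes the diagonal equal to per-class recall, so minority classes are not visually swamped by the majority class even when their test counts differ.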

Figure 9: Confusion matrices of the predictions on the test set of CIFAR-10-LT: (a) ABC; (b) TRAS.

Ablation studies. We conduct ablation studies on important parts of our approach under the main setting.

First, this paper applies the logit transformation with to the teacher model's prediction on unlabeled data for better performance of the teacher model. Removing the logit transformation reduces the overall accuracy and minority-class accuracy under the main setting to 83.66% (-0.64%) and 77.54% (-4.66%) on CIFAR-10-LT, respectively.

Second, we replace the distribution-aware cross-entropy on the labeled data with the common cross-entropy loss, which yields 83.41% (-0.89%) overall accuracy and 80.28% (-1.92%) minority-class accuracy. The decline is only marginal, which verifies that most of the gain comes from the learning-through-imitation approach.

Finally, we remove the sample mask on the student's unlabeled data, which means that all unlabeled data are used to imitate the teacher. Removing the sample mask slightly decreases the performance, to 83.33% (-0.97%) and 79.60% (-2.59%) for overall and minority-class accuracy, respectively. This demonstrates the advantage of selecting more accurate pseudo-labels for the student model.

4.5 Comparison with Two-stage Training

We further compare TRAS with two-stage training under the main setting in Table 6. In two-stage training, we first train a FixMatch model as the teacher and then use it to guide a student model with TRAS. We find that TRAS not only saves training cost but also outperforms the two-stage approach. This agrees with our expectation: in two-stage training, the student places more emphasis on the minority class, but fitting the pseudo-label distribution does not necessarily improve feature learning. In contrast, the two branches in TRAS share the feature-learning backbone, so the backbone and the classifiers improve simultaneously.

To further show that both branches of TRAS improve feature learning, we stop the gradients of the student branch from propagating into the feature-learning backbone. In this way, only the teacher model trains the feature extractor network, which is similar to two-stage training. The results show that full TRAS performs better than this variant (TRAS−), indicating that the student further enhances feature learning by sharing the backbone with the teacher model.

Method       CIFAR-10-LT   SVHN-LT     CIFAR-100-LT
Two-stage    80.6/75.0     91.6/89.6   57.1/47.6
TRAS−¹       83.2/80.8     92.6/92.1   57.8/50.3
TRAS         84.3/82.2     93.4/92.5   58.5/50.3
Table 6: Performance comparison of overall accuracy (%)/minority-class accuracy (%) with two-stage training and TRAS−.
¹ The student does not propagate its gradients to update the feature extractor.

5 Conclusion

We introduce TRAS, a new method for LTSSL. TRAS (1) learns from a more class-balanced label distribution to improve minority-class generalization and (2) partitions the parameter space, enabling the knowledge learned and transformed from the conventional SSL model to be transferred via weight sharing. Extensive experiments on the CIFAR-10-LT, SVHN-LT, and CIFAR-100-LT datasets show that TRAS outperforms state-of-the-art methods on the minority classes by a large margin. In future work, it would be interesting to extend TRAS to other established SSL methods.
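
Point (1) can be illustrated with a simple rebalancing of a pseudo-label distribution. The paper's own transformation function is defined in its method section and is not reproduced here. Instead, the sketch assumes a hypothetical prior-based rescaling, p̃_c ∝ p_c / π_c^τ, which is equivalent to subtracting τ·log π_c from the logits before the softmax. The function name and the example prior are illustrative.

```python
def rebalance(probs, class_prior, tau=1.0):
    """Hypothetical rebalancing: divide each class probability by prior**tau
    and renormalize, which raises the share of minority classes."""
    scaled = [p / (pi ** tau) for p, pi in zip(probs, class_prior)]
    total = sum(scaled)
    return [s / total for s in scaled]


prior = [0.8, 0.2]    # hypothetical long-tailed prior; class 1 is the minority
teacher = [0.8, 0.2]  # a teacher prediction that simply mirrors the prior
print(rebalance(teacher, prior))           # [0.5, 0.5]
print(rebalance(teacher, prior, tau=0.5))  # partially rebalanced, about [0.667, 0.333]
```

Larger `tau` moves the target further toward uniform. Under this assumption, a student trained on these targets sees a stronger supervisory signal for minority classes than the raw teacher output would provide.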

References

  • Amodei et al (2016) Amodei D, Ananthanarayanan S, Anubhai R, et al (2016) Deep Speech 2: End-to-end speech recognition in English and Mandarin. In: International Conference on Machine Learning, pp 173–182
  • Berthelot et al (2019a) Berthelot D, Carlini N, Cubuk ED, et al (2019a) ReMixMatch: Semi-supervised learning with distribution matching and augmentation anchoring. In: International Conference on Learning Representations
  • Berthelot et al (2019b) Berthelot D, Carlini N, Goodfellow I, et al (2019b) MixMatch: A holistic approach to semi-supervised learning. Advances in Neural Information Processing Systems 32:5050–5060
  • Branco et al (2016) Branco P, Torgo L, Ribeiro RP (2016) A survey of predictive modeling on imbalanced domains. ACM Computing Surveys 49(2):1–50
  • van Engelen and Hoos (2020) van Engelen JE, Hoos HH (2020) A survey on semi-supervised learning. Machine Learning 109(2):373–440
  • Gao et al (2017) Gao BB, Xing C, Xie CW, et al (2017) Deep label distribution learning with label ambiguity. IEEE Transactions on Image Processing 26(6):2825–2838
  • Geng (2016) Geng X (2016) Label distribution learning. IEEE Transactions on Knowledge and Data Engineering 28(7):1734–1748
  • Guo et al (2020) Guo LZ, Zhang ZY, Jiang Y, et al (2020) Safe deep semi-supervised learning for unseen-class unlabeled data. In: International Conference on Machine Learning, pp 3897–3906
  • He and Garcia (2009) He H, Garcia EA (2009) Learning from imbalanced data. IEEE Transactions on Knowledge and Data Engineering 21(9):1263–1284
  • He et al (2016) He K, Zhang X, Ren S, et al (2016) Deep residual learning for image recognition. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp 770–778
  • He et al (2021) He YY, Wu J, Wei XS (2021) Distilling virtual examples for long-tailed recognition. In: Proceedings of the IEEE/CVF International Conference on Computer Vision, pp 235–244
  • Iscen et al (2021) Iscen A, Araujo A, Gong B, et al (2021) Class-balanced distillation for long-tailed visual recognition. In: The British Machine Vision Conference
  • Kang et al (2020) Kang B, Xie S, Rohrbach M, et al (2020) Decoupling representation and classifier for long-tailed recognition. In: International Conference on Learning Representations
  • Kim et al (2020) Kim J, Hur Y, Park S, et al (2020) Distribution aligning refinery of pseudo-label for imbalanced semi-supervised learning. Advances in Neural Information Processing Systems 33:14,567–14,579
  • Kingma and Ba (2015) Kingma DP, Ba J (2015) Adam: A method for stochastic optimization. In: International Conference on Learning Representations
  • Lee et al (2021) Lee H, Shin S, Kim H (2021) ABC: Auxiliary balanced classifier for class-imbalanced semi-supervised learning. Advances in Neural Information Processing Systems 34:7082–7094
  • Li et al (2019) Li Y, Wang H, Wei T, et al (2019) Towards automated semi-supervised learning. In: Proceedings of the AAAI Conference on Artificial Intelligence, pp 4237–4244
  • Liu et al (2019) Liu Z, Miao Z, Zhan X, et al (2019) Large-scale long-tailed recognition in an open world. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp 2537–2546
  • Van der Maaten and Hinton (2008) Van der Maaten L, Hinton G (2008) Visualizing data using t-SNE. Journal of Machine Learning Research 9(11)
  • Menon et al (2020) Menon AK, Jayasumana S, Rawat AS, et al (2020) Long-tail learning via logit adjustment. In: International Conference on Learning Representations
  • Miyato et al (2018) Miyato T, Maeda Si, Koyama M, et al (2018) Virtual adversarial training: a regularization method for supervised and semi-supervised learning. IEEE Transactions on Pattern Analysis and Machine Intelligence 41(8):1979–1993
  • Ren et al (2020) Ren J, Yu C, Ma X, et al (2020) Balanced meta-softmax for long-tailed visual recognition. Advances in Neural Information Processing Systems 33:4175–4186
  • Sohn et al (2020) Sohn K, Berthelot D, Carlini N, et al (2020) FixMatch: Simplifying semi-supervised learning with consistency and confidence. Advances in Neural Information Processing Systems 33:596–608
  • Tarvainen and Valpola (2017) Tarvainen A, Valpola H (2017) Mean teachers are better role models: Weight-averaged consistency targets improve semi-supervised deep learning results. Advances in Neural Information Processing Systems 30:1195–1204
  • Wang and Geng (2019) Wang J, Geng X (2019) Classification with label distribution learning. In: International Joint Conference on Artificial Intelligence, pp 3712–3718
  • Wei et al (2021a) Wei C, Sohn K, Mellina C, et al (2021a) CReST: A class-rebalancing self-training framework for imbalanced semi-supervised learning. In: Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp 10,857–10,866
  • Wei and Li (2019) Wei T, Li YF (2019) Does tail label help for large-scale multi-label learning? IEEE Transactions on Neural Networks and Learning Systems 31(7):2315–2324
  • Wei et al (2021b) Wei T, Shi J, Tu W, et al (2021b) Robust long-tailed learning under label noise. CoRR abs/2108.11569
  • Wei et al (2022) Wei T, Shi J, Li Y, et al (2022) Prototypical classifier for robust class-imbalanced learning. In: Proceedings of the Pacific-Asia Conference on Knowledge Discovery and Data Mining, pp 44–57
  • Xiang et al (2020) Xiang L, Ding G, Han J (2020) Learning from multiple experts: Self-paced knowledge distillation for long-tailed classification. In: European Conference on Computer Vision, pp 247–263
  • Zagoruyko and Komodakis (2016) Zagoruyko S, Komodakis N (2016) Wide residual networks. In: The British Machine Vision Conference
  • Zhou et al (2021) Zhou Z, Guo LZ, Cheng Z, et al (2021) STEP: Out-of-distribution detection in the presence of limited in-distribution labeled data. Advances in Neural Information Processing Systems 34:29,168–29,180
  • Zhu et al (2022) Zhu B, Niu Y, Hua XS, et al (2022) Cross-domain empirical risk minimization for unbiased long-tailed classification. In: Proceedings of the AAAI Conference on Artificial Intelligence