GDW-NIPS2021
This repository is the official implementation of Generalized Data Weighting via Class-level Gradient Manipulation (NeurIPS 2021)(http://arxiv.org/abs/2111.00056).
Label noise and class imbalance are two major issues coexisting in real-world datasets. To alleviate the two issues, state-of-the-art methods reweight each instance by leveraging a small amount of clean and unbiased data. Yet, these methods overlook class-level information within each instance, which can be further utilized to improve performance. To this end, in this paper, we propose Generalized Data Weighting (GDW) to simultaneously mitigate label noise and class imbalance by manipulating gradients at the class level. To be specific, GDW unrolls the loss gradient to class-level gradients by the chain rule and reweights the flow of each gradient separately. In this way, GDW achieves remarkable performance improvement on both issues. Aside from the performance gain, GDW efficiently obtains class-level weights without introducing any extra computational cost compared with instance weighting methods. Specifically, GDW performs a gradient descent step on class-level weights, which only relies on intermediate gradients. Extensive experiments in various settings verify the effectiveness of GDW. For example, GDW outperforms state-of-the-art methods by 2.56% under the 60% uniform noise setting in CIFAR10. Our code is available at https://github.com/GGchen1997/GDW-NIPS2021.
Real-world classification datasets often suffer from two issues, i.e., label noise songLearningNoisyLabels2021 and class imbalance heLearningImbalancedData2009 . On the one hand, label noise often results from the limitation of data generation, e.g., sensor errors elhadySystematicSurveySensor2018a and mislabeling from crowdsourcing workers tongxiaoLearningMassiveNoisy2015 . Label noise misleads the training process of DNNs and degrades the model performance in various aspects alganLabelNoiseTypes2020b ; zhuClassNoiseVs2004a ; frenayClassificationPresenceLabel2014a . On the other hand, imbalanced datasets are either naturally long-tailed zhaoLongTailDistributionsUnsupervised2012a ; vanhornDevilTailsFinegrained2017a or biased from the real-world distribution due to imperfect data collection pavonAssessingImpactClassImbalanced2011a ; patelReviewClassificationImbalanced2020a . Training with imbalanced datasets usually results in poor classification performance on weakly represented classes dongClassRectificationHard2017a ; cuiClassBalancedLossBased2019 ; sinhaClassWiseDifficultyBalancedLoss2021a . Even worse, these two issues often coexist in real-world datasets johnsonSurveyDeepLearning2019a .
To prevent the model from memorizing noisy information, many important works have been proposed, including label smoothing szegedyRethinkingInceptionArchitecture2016a , noise adaptation goldbergerTrainingDeepNeuralnetworks2017 , importance weighting liuClassificationNoisyLabels2014 , GLC hendrycksUsingTrustedData2018 , and Co-teaching hanCoteachingRobustTraining2018a . Meanwhile, dongClassRectificationHard2017a ; cuiClassBalancedLossBased2019 ; sinhaClassWiseDifficultyBalancedLoss2021a ; linFocalLossDense2020 propose effective methods to tackle class imbalance. However, these methods inevitably introduce hyper-parameters (e.g., the weighting factor in cuiClassBalancedLossBased2019 and the focusing parameter in linFocalLossDense2020 ), complicating real-world deployment.
Inspired by recent advances in meta-learning, some works renLearningReweightExamples2018 ; shuMetaWeightNetLearningExplicit2019a ; huLearningDataManipulation2019a ; wangOptimizingDataUsage2020b propose to solve both issues by leveraging a clean and unbiased meta set. These methods treat instance weights as hyper-parameters and dynamically update these weights to circumvent hyper-parameter tuning. Specifically, MWNet shuMetaWeightNetLearningExplicit2019a adopts an MLP with the instance loss as input and the instance weight as output. Due to the MLP, MWNet has better scalability on large datasets compared with INSW huLearningDataManipulation2019a
which assigns each instance a learnable weight. Although these methods can handle label noise and class imbalance to some extent, they cannot fully utilize class-level information within each instance, resulting in the potential loss of useful information. For example, in a three-class classification task, every instance has three logits. As shown in Figure
1, every logit corresponds to a class-level gradient flow which stems from the loss function and back-propagates. These gradient flows represent three kinds of information: "not cat", "dog", and "not bird". Instance weighting methods
shuMetaWeightNetLearningExplicit2019a ; renLearningReweightExamples2018 alleviate label noise by downweighting all the gradient flows of the instance, which discards all three kinds of information simultaneously. Yet, downweighting the "not bird" gradient flow is a waste of information. Similarly, in class imbalance scenarios, different gradient flows represent different class-level information. Therefore, it is necessary to reweight instances at the class level for better information usage.
To this end, we propose Generalized Data Weighting (GDW) to tackle label noise and class imbalance by class-level gradient manipulation. Firstly, we introduce class-level weights to represent the importance of different gradient flows and manipulate the gradient flows with these class-level weights. Secondly, we impose a zero-mean constraint on class-level weights for stable training. Thirdly, to efficiently obtain class-level weights, we develop a two-stage weight generation scheme embedded in bi-level optimization. As a side note, the instance weighting methods renLearningReweightExamples2018 ; shuMetaWeightNetLearningExplicit2019a ; huLearningDataManipulation2019a ; wangOptimizingDataUsage2020b can be considered special cases of GDW when class-level weights within any instance are the same. In this way, GDW achieves impressive performance improvement in various settings.
To sum up, our contribution is two-fold:
For better information utilization, we propose GDW, a generalized data weighting method, which better handles label noise and class imbalance. To the best of our knowledge, we are the first to propose single-label class-level weighting on gradient flows.
To obtain class-level weights efficiently, we design a two-stage scheme embedded in a bi-level optimization framework, which does not introduce any extra computational cost. To be specific, during the back-propagation we store intermediate gradients, with which we update class-level weights via a gradient descent step.
Label noise is a common problem in classification tasks zhuClassNoiseVs2004a ; frenayClassificationPresenceLabel2014a ; alganLabelNoiseTypes2020b . To avoid overfitting to label noise, szegedyRethinkingInceptionArchitecture2016a propose label smoothing to regularize the model. goldbergerTrainingDeepNeuralnetworks2017 ; vahdatRobustnessLabelNoise2017a form different models to indicate the relation between noisy instances and clean instances. liuClassificationNoisyLabels2014 estimate an importance weight for each instance to represent its value to the model. hanCoteachingRobustTraining2018a train two models simultaneously and let them teach each other in every mini-batch. However, without a clean dataset, these methods cannot handle severe noise renLearningReweightExamples2018 . hendrycksUsingTrustedData2018 correct the prediction of the model by estimating the label corruption matrix via a clean validation set, but this matrix is the same across all instances. Instead, our method generates dynamic class-level weights for every instance to improve training.
Many important works have been proposed to handle class imbalance elkanFoundationsCostsensitiveLearning2001 ; chawlaSMOTESyntheticMinority2002a ; anandApproachClassificationHighly2010a ; dongClassRectificationHard2017a ; khanCostSensitiveLearningDeep2018 ; cuiClassBalancedLossBased2019 ; kangDecouplingRepresentationClassifier2019 ; linFocalLossDense2020 ; sinhaClassWiseDifficultyBalancedLoss2021a . chawlaSMOTESyntheticMinority2002a ; anandApproachClassificationHighly2010a propose to over-sample the minority class and under-sample the majority class. elkanFoundationsCostsensitiveLearning2001 ; khanCostSensitiveLearningDeep2018 learn a class-dependent cost matrix to obtain robust representations for both majority and minority classes. dongClassRectificationHard2017a ; cuiClassBalancedLossBased2019 ; linFocalLossDense2020 ; sinhaClassWiseDifficultyBalancedLoss2021a design a reweighting scheme to rebalance the loss for each class. These methods are quite effective, whereas they need to manually choose loss functions or hyper-parameters. liuLargeScaleLongTailedRecognition2019 ; wangLongtailedRecognitionRouting2020 manipulate the feature space to handle class imbalance while introducing extra model parameters. kangDecouplingRepresentationClassifier2019
decouple representation learning and classifier learning on long-tailed datasets, but with extra hyper-parameter tuning. In contrast, meta-learning methods view instance weights as hyper-parameters and dynamically update them via a meta set to avoid hyper-parameter tuning.
With recent development in meta-learning lakeHumanlevelConceptLearning2015a ; franceschiBilevelProgrammingHyperparameter2018a ; liuDARTSDifferentiableArchitecture2018 , many important methods have been proposed to handle label noise and class imbalance via a meta set wangLearningModelTail2017 ; jiangMentorNetLearningDataDriven2018a ; renLearningReweightExamples2018 ; liLearningLearnNoisy2019a ; shuMetaWeightNetLearningExplicit2019a ; huLearningDataManipulation2019a ; wangOptimizingDataUsage2020b ; alganMetaSoftLabel2021a . jiangMentorNetLearningDataDriven2018a propose MentorNet to provide a data-driven curriculum for the base network to focus on correct instances. To distill effective supervision, zhangDistillingEffectiveSupervision2020a estimate pseudo labels for noisy instances with a meta set. To provide dynamic regularization, vyasLearningSoftLabels2020b ; alganMetaSoftLabel2021a treat labels as learnable parameters and adapt them to the model’s state. Although these methods can tackle label noise, they introduce huge amounts of learnable parameters and thus cannot scale to large datasets. To alleviate class imbalance, wangLearningModelTail2017 describe a method to learn from long-tailed datasets. Specifically, wangLearningModelTail2017
propose to encode meta-knowledge into a meta-network and model the tail classes by transfer learning.
Focal linFocalLossDense2020 | Balanced cuiClassBalancedLossBased2019 | Co-teaching hanCoteachingRobustTraining2018a | GLC hendrycksUsingTrustedData2018 | L2RW renLearningReweightExamples2018 | INSW huLearningDataManipulation2019a | MWNet shuMetaWeightNetLearningExplicit2019a | Soft-label vyasLearningSoftLabels2020b | Gen-label alganMetaSoftLabel2021a | GDW | |
---|---|---|---|---|---|---|---|---|---|---|
Noise | ||||||||||
Imbalance | ||||||||||
Class-level | ||||||||||
Scalability |
Furthermore, many meta-learning methods propose to mitigate the two issues by reweighting every instance renLearningReweightExamples2018 ; saxenaDataParametersNew2019a ; shuMetaWeightNetLearningExplicit2019a ; huLearningDataManipulation2019a ; wangOptimizingDataUsage2020b . saxenaDataParametersNew2019a equip each instance and each class with a learnable parameter to govern their importance. By leveraging a meta set, renLearningReweightExamples2018 ; shuMetaWeightNetLearningExplicit2019a ; huLearningDataManipulation2019a ; wangOptimizingDataUsage2020b learn instance weights and model parameters via bi-level optimization to tackle label noise and class imbalance. renLearningReweightExamples2018 assign weights to training instances only based on their gradient directions. Furthermore, huLearningDataManipulation2019a
combine reinforcement learning and meta-learning, and treat instance weights as rewards for optimization. However, since each instance is directly assigned a learnable weight, INSW cannot scale to large datasets. Meanwhile,
shuMetaWeightNetLearningExplicit2019a ; wangOptimizingDataUsage2020b adopt a weighting network to output weights for instances and use bi-level optimization to jointly update the weighting network parameters and model parameters. Although these methods handle label noise and class imbalance by reweighting instances, a scalar weight for every instance cannot capture class-level information, as shown in Figure 1. Therefore, we introduce class-level weights for different gradient flows and adjust them to better utilize class-level information. We show the differences between GDW and other related methods in Table 1.
In most classification tasks, there is a training set $\mathcal{D}^{tr}$, and we assume there is also a clean and unbiased meta set $\mathcal{D}^{meta}$. We aim to alleviate label noise and class imbalance in $\mathcal{D}^{tr}$ with the help of $\mathcal{D}^{meta}$. The model parameters are denoted as $\theta$, and the number of classes is denoted as $C$.
To utilize class-level information, we learn a class-level weight for every gradient flow instead of a single scalar weight for all gradient flows as in shuMetaWeightNetLearningExplicit2019a . Denote $L$ as the loss of any instance and $z \in \mathbb{R}^C$ as its predicted logit vector. Applying the chain rule, we unroll the gradient of $L$ w.r.t. $\theta$ as

$\frac{\partial L}{\partial \theta} = \sum_{i=1}^{C} \frac{\partial L}{\partial z_i}\, \frac{\partial z_i}{\partial \theta}$ (1)

where $\frac{\partial L}{\partial z_i}\, \frac{\partial z_i}{\partial \theta}$ is the class-level gradient flow of class $i$. We introduce class-level weights $w \in \mathbb{R}^C$ and denote the $i$-th component of $w$ as $w_i$. To indicate the importance of every gradient flow, we perform an element-wise product on $\frac{\partial L}{\partial z}$ with $w$. After this manipulation, the gradient becomes

$\frac{\partial \tilde{L}}{\partial \theta} = \sum_{i=1}^{C} \Big(w \odot \frac{\partial L}{\partial z}\Big)_i\, \frac{\partial z_i}{\partial \theta}$ (2)
where $\odot$ denotes the element-wise product of two vectors. Note that $w_i$ represents the importance of the $i$-th gradient flow. Obviously, instance weighting is a special case of GDW where all elements of $w$ are the same. Most classification tasks howardMobileNetsEfficientConvolutional2017a ; qinRethinkingSoftmaxCrossEntropy2020a ; zhaoBetterAccuracyefficiencyTradeoffs2021a adopt the Softmax-CrossEntropy loss. In this case, we have $\frac{\partial L}{\partial z} = p - y$, where $p$ denotes the probability vector output by softmax and $y$ denotes the one-hot label of the instance (see Appendix A for details).

As shown in Figure 1, for a noisy instance (e.g., a cat mislabeled as "dog"), instance weighting methods assign a low scalar weight to all gradient flows of the instance. Instead, GDW assigns class-level weights to different gradient flows by leveraging the meta set. In other words, GDW tries to downweight the gradient flows for "dog" and "not cat", and upweight the gradient flow for "not bird". Similarly, in imbalance settings, different gradient flows carry different class-level information. Thus GDW can also better handle class imbalance by adjusting the importance of different gradient flows.
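To make the manipulation concrete, here is a small numpy sketch (not the authors' code; the logits, label, and weight values are made up) of the element-wise weighting of Eq. (2) for the Softmax-CrossEntropy case, where the logit gradient is $p - y$:

```python
import numpy as np

def softmax(z):
    """Numerically stable softmax."""
    e = np.exp(z - z.max())
    return e / e.sum()

# Hypothetical three-class example: a cat image mislabeled as "dog".
z = np.array([2.0, 1.0, -1.0])   # logits for [cat, dog, bird]
y = np.array([0.0, 1.0, 0.0])    # noisy one-hot label: "dog"
p = softmax(z)

grad = p - y                     # dL/dz for Softmax-CrossEntropy
w = np.array([0.2, 0.2, 1.0])    # class-level weights: downweight the
                                 # harmful "not cat" and "dog" flows,
                                 # keep the useful "not bird" flow
weighted_grad = w * grad         # Eq. (2): element-wise product

# Instance weighting is the special case of a constant weight vector:
instance_grad = 0.2 * grad
```

Back-propagating `weighted_grad` instead of `grad` through $\frac{\partial z}{\partial \theta}$ reweights each class-level flow separately, whereas `instance_grad` scales all three flows together.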
To retain the Softmax-CrossEntropy loss structure, i.e., the $p - y$ form, after the manipulation, we impose a zero-mean constraint on $w$. That is, we analyze the $i$-th element of $w \odot \frac{\partial L}{\partial z}$ (see Appendix B.1 for details):

$w_i\, \frac{\partial L}{\partial z_i} = w_t\, (\hat{p}_i - y_i) + \Big(\sum_{j=1}^{C} w_j\, p_j - w_t\Big)\, \hat{p}_i$ (3)

where $\hat{p}_i = \frac{w_i\, p_i}{\sum_j w_j\, p_j}$ is the weighted probability, and $w_t$ denotes the class-level weight at the target (label) position. We observe that the first term in Eq. (3) satisfies the structure of the gradient of the Softmax-CrossEntropy loss, and thus propose to eliminate the second term, which breaks the structure. Specifically, we let

$\sum_{i \neq t} w_i\, p_i = w_t\, (1 - p_t)$ (4)

where $p_t$ is the probability of the target class. Note that $\sum_{i=1}^{C} p_i = 1$, and thus we have

$\sum_{i=1}^{C} (w_i - w_t)\, p_i = 0$ (5)

This restricts the $p$-weighted mean of $w - w_t$ to be zero. Therefore, we name this constraint the zero-mean constraint. With it, we have

$w_i\, \frac{\partial L}{\partial z_i} = w_t\, (\hat{p}_i - y_i), \quad \hat{p}_i = \frac{w_i\, p_i}{w_t}$ (6)
Eq. (6) indicates that $w$ adjusts the gradients at two levels, i.e., the instance level and the class level. Namely, the scalar $w_t$ acts as the instance-level weight in previous instance weighting methods renLearningReweightExamples2018 ; shuMetaWeightNetLearningExplicit2019a ; huLearningDataManipulation2019a ; wangOptimizingDataUsage2020b , and the $w_i$'s are the class-level weights manipulating gradient flows by adjusting the probability from $p_i$ to $\hat{p}_i$.
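One simple way to enforce the constraint in code is to rescale the non-target weights so that Eq. (4) holds exactly. The following is a sketch of our own (the paper's exact projection may differ):

```python
import numpy as np

def zero_mean_project(w, p, t):
    """Rescale non-target weights so that sum_i (w_i - w_t) p_i = 0 (Eq. (5)).

    w: class-level weights, p: softmax probabilities, t: target index.
    A hypothetical projection; the paper may normalize differently.
    """
    w = w.copy()
    mask = np.ones_like(w, dtype=bool)
    mask[t] = False
    s = (w[mask] * p[mask]).sum()
    if s > 0:
        # Enforce Eq. (4): sum_{i != t} w_i p_i = w_t (1 - p_t)
        w[mask] *= w[t] * (1.0 - p[t]) / s
    return w

w = zero_mean_project(np.array([0.9, 1.2, 0.8]), np.array([0.5, 0.3, 0.2]), t=1)
# After projection, the p-weighted mean of (w - w_t) is zero and the
# target weight w_t itself is untouched.
```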
In this subsection, we first illustrate the three-step bi-level optimization framework in shuMetaWeightNetLearningExplicit2019a . Furthermore, we embed a two-stage scheme in the bi-level optimization framework to efficiently obtain class-level weights, with which we manipulate gradient flows and optimize model parameters.
Three-step Bi-level Optimization. Generally, the goal of classification tasks is to obtain the optimal model parameters $\theta^*$ by minimizing the average loss on $\mathcal{D}^{tr}$. As an instance weighting method, shuMetaWeightNetLearningExplicit2019a adopts a three-layer MLP parameterized by $\phi$ as the weighting network, which takes the loss $L_j$ of the $j$-th training instance as input and outputs a scalar weight $\mathcal{V}(L_j; \phi)$. Then $\theta$ is optimized by minimizing the instance-level weighted training loss:

$\theta^*(\phi) = \arg\min_{\theta} \frac{1}{N} \sum_{j=1}^{N} \mathcal{V}\big(L_j(\theta); \phi\big)\, L_j(\theta)$ (7)

To obtain the optimal $\phi^*$, they propose to use a meta set as meta-knowledge and minimize the meta-loss:

$\phi^* = \arg\min_{\phi} \frac{1}{M} \sum_{i=1}^{M} L_i^{meta}\big(\theta^*(\phi)\big)$ (8)

where $N$ and $M$ denote the sizes of the training set and the meta set, respectively. Since the optimization for $\theta$ and $\phi$ is nested, they adopt an online strategy that updates $\theta$ and $\phi$ with a three-step optimization loop for efficiency. Denote the two sets of parameters at the $k$-th loop as $\theta^k$ and $\phi^k$ respectively; the three-step loop is formulated as follows.
Step 1: Update $\theta^k$ to $\hat{\theta}^k$ via an SGD step on a mini-batch training set by Eq. (7).
Step 2: With $\hat{\theta}^k$, update $\phi^k$ to $\phi^{k+1}$ via an SGD step on a mini-batch meta set by Eq. (8).
Step 3: With $\phi^{k+1}$, update $\theta^k$ to $\theta^{k+1}$ via an SGD step on the same mini-batch training set by Eq. (7).

Instance weights in Step 3 are better than those in Step 1, and thus are used to update $\theta$.
Two-stage Weight Generation. To guarantee scalability, we apply the same weighting network as in shuMetaWeightNetLearningExplicit2019a to obtain weights. To efficiently train $\theta$ and $\phi$, we also adopt the three-step bi-level optimization framework. Moreover, we propose an efficient two-stage scheme embedded in Steps 1-3 to generate class-level weights. This process does not introduce any extra computational cost compared to shuMetaWeightNetLearningExplicit2019a . We keep the notations of $\theta^k$ and $\phi^k$ unchanged.
The first stage is embedded in Step 1. Explicitly, we obtain the first-stage class-level weights $w_j^{(1)}$ of the $j$-th instance by cloning the scalar output of the weighting network $C$ times. Then we leverage the cloned weights to manipulate gradients and update $\theta^k$ with a mini-batch of training instances:

$\hat{\theta}^k = \theta^k - \frac{\alpha}{n} \sum_{j=1}^{n} \mathcal{G}\big(L_j(\theta^k),\, w_j^{(1)}\big)$ (9)

where $n$ is the mini-batch size, $\alpha$ is the learning rate of $\theta$, and $\mathcal{G}$ is the gradient manipulation operation defined in Eq. (2).
The second stage is embedded in Step 2 and Step 3. Specifically, in Step 2, GDW optimizes $\phi$ with a mini-batch meta set:

$\phi^{k+1} = \phi^k - \frac{\beta}{m} \sum_{i=1}^{m} \nabla_{\phi} L_i^{meta}\big(\hat{\theta}^k(\phi)\big)$ (10)

where $m$ is the mini-batch size and $\beta$ is the learning rate of $\phi$. During the back-propagation in updating $\phi$, GDW generates the second-stage weights $w^{(2)}$ using the intermediate gradients on $w^{(1)}$. Precisely,

$w_j^{(2)} = w_j^{(1)} - \eta\, \mathrm{clip}\Big(\nabla_{w_j^{(1)}} \frac{1}{m} \sum_{i=1}^{m} L_i^{meta}\big(\hat{\theta}^k\big),\ -\delta,\ \delta\Big)$ (11)

where $\eta$ represents the learning rate of $w$ and $\delta$ denotes the clip parameter. Then we impose the zero-mean constraint proposed in Eq. (4) on $w^{(2)}$, which is later used in Step 3 to update $\theta^k$. Note that the two-stage weight generation scheme does not introduce any extra computational cost compared to MWNet because this generation process only utilizes the intermediate gradients during the back-propagation. In Step 3, we use $w^{(2)}$ to manipulate gradients and update the model parameters $\theta^k$:
$\theta^{k+1} = \theta^k - \frac{\alpha}{n} \sum_{j=1}^{n} \mathcal{G}\big(L_j(\theta^k),\, w_j^{(2)}\big)$ (12)

The only difference between Step 1 and Step 3 is that we use $w^{(2)}$ instead of the cloned output $w^{(1)}$ of the weighting network to optimize $\theta$. Since we only introduce $w$ as extra learnable parameters, GDW can scale to large datasets. We summarize GDW in Algorithm 1. Moreover, we visualize the two-stage weight generation process in Figure 2 for better demonstration.
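Putting the three steps and the two-stage weights together, the following numpy sketch runs one loop iteration on a tiny linear softmax model. All shapes and learning rates are our own toy choices; the paper uses deep networks, a weighting-network output for the first-stage weights, and applies the zero-mean constraint to the second-stage weights, which we omit here for brevity:

```python
import numpy as np

rng = np.random.default_rng(0)
C, D, n = 3, 5, 8                       # classes, features, batch size
theta = rng.normal(size=(D, C)) * 0.1   # linear model: logits = X @ theta
alpha, eta, delta = 0.1, 0.5, 1.0       # lr of theta, lr of w, clip parameter

def softmax(Z):
    E = np.exp(Z - Z.max(axis=1, keepdims=True))
    return E / E.sum(axis=1, keepdims=True)

def weighted_grad(X, Y, theta, W):
    """Averaged model gradient with class-level weighted logit gradients (Eq. (2))."""
    P = softmax(X @ theta)
    return X.T @ (W * (P - Y)) / len(X)

X = rng.normal(size=(n, D)); Y = np.eye(C)[rng.integers(0, C, n)]    # train batch
Xm = rng.normal(size=(n, D)); Ym = np.eye(C)[rng.integers(0, C, n)]  # meta batch

# Step 1: pseudo-update with first-stage weights (cloned scalars, here all ones).
W1 = np.ones((n, C))
theta_hat = theta - alpha * weighted_grad(X, Y, theta, W1)

# Step 2: meta gradient w.r.t. theta_hat, then the intermediate gradient of the
# meta loss w.r.t. each class-level weight through the pseudo-update:
# dL_meta/dW1[j, i] = -(alpha/n) * (P - Y)[j, i] * (x_j . dL_meta/dtheta_hat[:, i])
meta_g = Xm.T @ (softmax(Xm @ theta_hat) - Ym) / n
gW = -alpha * (softmax(X @ theta) - Y) * (X @ meta_g) / n

# Second-stage weights via one clipped gradient descent step (Eq. (11)).
W2 = W1 - eta * np.clip(gW, -delta, delta)

# Step 3: real update with the second-stage weights (Eq. (12)).
theta_new = theta - alpha * weighted_grad(X, Y, theta, W2)
```

The key point of the sketch is that `gW` reuses gradients that the meta back-propagation already computes, so the second-stage weights come essentially for free.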
We conduct extensive experiments on classification tasks to examine the performance of GDW. We compare GDW with other methods in the label noise setting and class imbalance setting in Section 4.1 and Section 4.2, respectively. Next, we perform experiments on the real-world dataset Clothing1M tongxiaoLearningMassiveNoisy2015 in Section 4.3. We conduct further experiments to verify the performance of GDW in the mixed setting, i.e. the coexistence of label noise and class imbalance (see Appendix F for details).
Setup. Following shuMetaWeightNetLearningExplicit2019a , we study two settings of label noise: a) Uniform noise: every instance's label flips uniformly to one of the other class labels with probability $p$; b) Flip noise: each class randomly flips to another single class with probability $p$. Note that the probability $p$ represents the noise ratio. We randomly select a small number of clean images per class from CIFAR10 krizhevskyLearningMultipleLayers2009 as the meta set, and similarly select a small meta set from CIFAR100. We use ResNet-32 heDeepResidualLearning2016 as the classifier model.
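For reference, the two noise models can be simulated as follows (a sketch with hypothetical function names; the benchmark's exact corruption procedure may differ in details):

```python
import numpy as np

def uniform_noise(labels, num_classes, p, rng):
    """Flip each label to a uniformly chosen *other* class with probability p."""
    labels = labels.copy()
    for j in np.where(rng.random(len(labels)) < p)[0]:
        others = [c for c in range(num_classes) if c != labels[j]]
        labels[j] = rng.choice(others)
    return labels

def flip_noise(labels, num_classes, p, rng):
    """Flip each class to one fixed, randomly chosen other class with probability p."""
    labels = labels.copy()
    # one fixed wrong destination per class
    targets = np.array([rng.choice([c for c in range(num_classes) if c != k])
                        for k in range(num_classes)])
    flip = rng.random(len(labels)) < p
    labels[flip] = targets[labels[flip]]
    return labels
```

In both models the noise ratio is the flip probability `p`; the difference is that flip noise concentrates all corruption of a class onto a single wrong class.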
Comparison methods. We mainly compare GDW with meta-learning methods: 1) L2RW renLearningReweightExamples2018 , which assigns weights to instances based on gradient directions; 2) INSW huLearningDataManipulation2019a , which derives instance weights adaptively from the meta set; 3) MWNet shuMetaWeightNetLearningExplicit2019a ; 4) Soft-label vyasLearningSoftLabels2020b , which learns a label smoothing parameter for every instance; 5) Gen-label alganMetaSoftLabel2021a , which generates a meta-soft-label for every instance. We also compare GDW with some traditional methods: 6) BaseModel, which trains ResNet-32 on the noisy training set; 7) Fine-tuning, which uses the meta set to fine-tune the trained model in BaseModel; 8) Co-teaching hanCoteachingRobustTraining2018a ; 9) GLC hendrycksUsingTrustedData2018 .
Training. Most of our training settings follow shuMetaWeightNetLearningExplicit2019a and we use the cosine learning rate decay schedule loshchilovSGDRStochasticGradient2016 for a total of epochs for all methods. See Appendix C for details.
Dataset | CIFAR10 | CIFAR100 | ||||
---|---|---|---|---|---|---|
BaseModel | ||||||
Fine-tuning | ||||||
Co-teaching | ||||||
GLC | ||||||
L2RW | ||||||
INSW | ||||||
MWNet | 92.95 0.33 | |||||
Soft-label | ||||||
Gen-label | ||||||
GDW | 92.94 0.15 | 88.14 0.35 | 84.11 0.21 | 70.65 0.52 | 59.82 1.62 | 53.33 3.70 |
Dataset | CIFAR10 | CIFAR100 | ||||
---|---|---|---|---|---|---|
BaseModel | ||||||
Fine-tuning | ||||||
Co-teaching | ||||||
GLC | 89.74 0.19 | 63.11 0.93 | ||||
L2RW | ||||||
INSW | ||||||
MWNet | 92.95 0.33 | |||||
Soft-label | ||||||
Gen-label | ||||||
GDW | 92.94 0.15 | 91.05 0.26 | 70.65 0.52 | 65.41 0.75 |
Analysis. For all experiments, we report the mean and standard deviation over repeated runs in Table 2 and Table 3, where the best results are in bold and the second-best results are marked by underlines. First, we observe that GDW outperforms nearly all the competing methods in all noise settings except for the flip noise setting, where GLC estimates the label corruption matrix well and thus performs the best; however, the flip noise assumption scarcely holds in real-world scenarios. Note that GLC also performs much better than MWNet under the flip noise setting, as reported in shuMetaWeightNetLearningExplicit2019a . Besides, under all noise settings, GDW achieves a consistent performance gain over MWNet, which aligns with our motivation in Figure 1. Furthermore, as the uniform noise ratio increases, the gap between GDW and MWNet widens in both CIFAR10 and CIFAR100. Even under heavy uniform noise, GDW still has low test errors in both datasets and outperforms the second-best method in CIFAR10 and CIFAR100. Last but not least, GDW outperforms Soft-label and Gen-label in all settings. One possible reason is that manipulating gradient flows is a more direct way to capture class-level information than learning labels.

In Figure 4, we show the distribution of the class-level target weight ($w_t$) on clean and noisy instances in one epoch under the CIFAR10 40% uniform noise setting. We observe that $w_t$ of most clean instances is larger than that of most noisy instances, which indicates that $w_t$ can distinguish between clean and noisy instances. This is consistent with Eq. (3), in which $w_t$ serves as the instance weight.
To better understand the changing trend of non-target class-level weights, we visualize the ratio of increased weights in one epoch in Figure 6 under the CIFAR10 40% uniform noise setting. Specifically, there are three categories: non-target weights on clean instances, true-target weights on noisy instances, and non-target (excluding true-target) weights on noisy instances. Formally, a "target weight" is the class-level weight at the label position; a "true-target weight" is the class-level weight at the true label position, which is only applicable to noisy instances; a "non-target weight" is any class-level weight except those at the label position and the true label position. For example, for the cat mislabeled as "dog" in Figure 1, "dog" is the target, "cat" is the true target, and "bird" is one of the non-targets. For a correctly labeled cat, "cat" is both the target and the true target, while "dog" and "bird" are both non-targets.

Note that in Figure 1, the class-level weights at the "cat" and "bird" positions represent the importance of the "not cat" and "not bird" gradient flows, respectively. If the cat image in Figure 1 were correctly labeled as "cat", the two non-target weights would instead represent the importance of the "not dog" and "not bird" gradient flows. In one epoch, we calculate the ratio of the number of increased weights in each category to the number of all corresponding weights. Non-target weights on clean instances and non-target weights on noisy instances are expected to increase, since their gradient flows contain valuable information, whereas true-target weights on noisy instances are expected to decrease, because gradient flows such as "not cat" contain harmful information. Figure 6 aligns with this expectation. Note that the lines of the two non-target categories nearly coincide and fluctuate at a similar level, which means non-target weights on clean instances and noisy instances share the same changing pattern. Besides, fewer true-target weights on noisy instances increase than decrease, which means their gradient flows contain much harmful information.
In Figure 4, we show the change of class-level weights in an iteration for a noisy instance, i.e., a cat image mislabeled as "dog". The gradient flows of "not cat" and "dog" contain harmful information and thus are downweighted by GDW. In addition, GDW upweights the valuable "not bird" gradient flow from to . By contrast, unable to capture class-level information, MWNet downweights all gradient flows from to , which leads to information loss on the "not bird" gradient flow.
Training without the zero-mean constraint. We have also tried training without the zero-mean constraint in Section 3.3 and obtained poor results (see Appendix B.2 for details). Denote the true target as $t^*$ and one of the non-target labels as $i$ ($i \neq t^*$). Note that, without the constraint, the gradient can be unrolled as (see Appendix B.2 for details):

$\frac{\partial \tilde{L}}{\partial \theta} = w_t \sum_{j=1}^{C} (\hat{p}_j - y_j)\, \frac{\partial z_j}{\partial \theta} + \epsilon \sum_{j=1}^{C} \hat{p}_j\, \frac{\partial z_j}{\partial \theta}, \quad \epsilon = \sum_{j=1}^{C} (w_j - w_t)\, p_j$ (13)

If $\epsilon$ is positive and the learning rate is small enough, the second term contributes to the decrease of the true-target logit $z_{t^*}$ after a gradient descent step. If negative, it contributes to the increase of the non-target logit $z_i$. Therefore, without the zero-mean constraint, the second term in Eq. (13) may hurt the performance of the model regardless of the sign of $\epsilon$. Similarly, training without the constraint results in poor performance in other settings. Hence we omit those results in the following subsections.
Dataset | CIFAR10 | CIFAR100 | ||||
---|---|---|---|---|---|---|
BaseModel | ||||||
Fine-tuning | ||||||
Focal | ||||||
Balanced | ||||||
L2RW | ||||||
INSW | ||||||
MWNet | 92.95 0.33 | |||||
GDW | 92.94 0.15 | 86.77 0.55 | 71.31 1.03 | 70.65 0.52 | 56.78 0.52 | 37.94 1.58 |
Setup and comparison methods. The imbalance factor $\mu$ of a dataset is defined as the number of instances in the smallest class divided by that of the largest shuMetaWeightNetLearningExplicit2019a . Long-Tailed CIFAR krizhevskyLearningMultipleLayers2009 is created by reducing the number of training instances per class according to an exponential function $n_i = n\, \mu^{i/(C-1)}$, where $i$ is the class index (0-indexed) and $n$ is the original number of training instances per class. Comparison methods include: 1) L2RW renLearningReweightExamples2018 ; 2) INSW huLearningDataManipulation2019a ; 3) MWNet shuMetaWeightNetLearningExplicit2019a ; 4) BaseModel; 5) Fine-tuning; 6) Balanced cuiClassBalancedLossBased2019 ; 7) Focal linFocalLossDense2020 .

Analysis. As shown in Table 4, GDW performs best in nearly all settings and exceeds MWNet by a clear margin in CIFAR10. Besides, INSW achieves competitive performance at the cost of introducing a huge number of learnable parameters (equal to the training dataset size). Furthermore, we find that BaseModel achieves competitive performance, but fine-tuning on the meta set hurts the model's performance. We have tried different learning rates for fine-tuning, but the results are similar. One explanation is that the balanced meta set worsens the model learned from the imbalanced training set. These results align with the experimental results in huLearningDataManipulation2019a , which also deals with class imbalance.
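As a concrete illustration, with imbalance factor mu (smallest class size divided by largest), the per-class counts of such a long-tailed split can be generated as below; the exact exponent convention is our assumption, chosen so that the smallest class keeps mu times as many instances as the largest:

```python
def long_tailed_sizes(n, num_classes, mu):
    """Per-class training set sizes for an exponentially decaying long tail.

    n: original (largest) per-class size; mu: imbalance factor, so the
    smallest class keeps n * mu instances. Exponent convention assumed.
    """
    return [int(n * mu ** (i / (num_classes - 1))) for i in range(num_classes)]

sizes = long_tailed_sizes(5000, 10, 0.01)
# head class keeps 5000 instances, tail class keeps 50
```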
Consider the smallest and the second smallest classes in Long-Tailed CIFAR10. Recall that a non-target weight represents the importance of "not this class" information. For all instances in an epoch, we calculate the ratio of the number of increased non-target weights on these two classes to the number of all such weights, and then visualize the ratio trend in Figure 6. Since they are the smallest classes, instance weighting methods would upweight all their gradient flows on an instance. Yet in Figure 6, fewer of these non-target weights increase than decrease. This can be explained as follows. With regard to a tail class, there are two kinds of information in the long-tailed dataset: the class itself and "not" the class. Since the class belongs to the minority, the dataset is biased towards the "not" information. Because the non-target weight represents the importance of the "not" information, a smaller weight weakens it. As a result, decreased non-target weights achieve a balance between the two kinds of information, thus better handling class imbalance at the class level. We have conducted further experiments on imbalanced settings to verify the effectiveness of GDW (see Appendix D for details).
Setup and training. The Clothing1M dataset contains one million images from fourteen classes collected from the web tongxiaoLearningMassiveNoisy2015 . Labels are constructed from the surrounding texts of images and thus contain some errors. We use the ResNet-18 model pre-trained on ImageNet dengImageNetLargescaleHierarchical2009 as the classifier. The comparison methods are the same as those in the label noise setting, since the main issue of Clothing1M is label noise tongxiaoLearningMassiveNoisy2015 . All methods are trained via SGD with momentum and weight decay; the number of epochs, initial learning rate, and batch size are detailed in Appendix E.

Method | BaseModel | Fine-tuning | Co-teaching | GLC | L2RW | INSW | MWNet | Soft-label | Gen-label | GDW
---|---|---|---|---|---|---|---|---|---|---
Accuracy (%) | 69.39 | | | | | | | | |
Analysis. As shown in Table 5, GDW achieves the best performance among all the comparison methods and outperforms MWNet by . In contrast to unsatisfying results in previous settings, L2RW performs quite well in this setting. One possible explanation is that, compared with INSW and MWNet which update weights iteratively, L2RW obtains instance weights only based on current gradients. As a result, L2RW can more quickly adapt to the model’s state, but meanwhile suffers from unstable weights shuMetaWeightNetLearningExplicit2019a . In previous settings, we train models from scratch, which need stable weights to stabilize training. Therefore, INSW and MWNet generally achieve better performance than L2RW. Whereas in this setting, we use the pre-trained ResNet-18 model which is already stable enough. Thus L2RW performs better than INSW and MWNet.
Many instance weighting methods have recently been proposed to tackle label noise and class imbalance, but they cannot capture class-level information. For better information use when handling the two issues, we propose GDW to generalize data weighting from instance level to class level by reweighting gradient flows. Besides, to efficiently obtain class-level weights, we design a two-stage weight generation scheme which is embedded in a three-step bi-level optimization framework and leverages intermediate gradients to update class-level weights via a gradient descent step. In this way, GDW achieves remarkable performance improvement in various settings.
The limitations of GDW are two-fold. Firstly, the gradient manipulation is only applicable to single-label classification tasks; when applied to multi-label tasks, the formulation of gradient manipulation needs some modification. Secondly, GDW does not outperform comparison methods by a large margin in class imbalance settings despite the potential effectiveness analyzed in Section 4.2. One possible explanation is that better information utilization does not necessarily translate into performance gains, which also depend on various other factors.
We thank Prof. Hao Zhang from Tsinghua University for helpful suggestions. This research was supported in part by the MSR-Mila collaboration funding, and enabled in part by computational support provided by Compute Canada (www.computecanada.ca).
Learning from Massive Noisy Labeled Data for Image Classification. In 2015 IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pages 2691–2699, Boston, MA, USA, June 2015. IEEE.

Classification in the Presence of Label Noise: A Survey. IEEE Transactions on Neural Networks and Learning Systems, 25(5):845–869, May 2014.

Long-Tail Distributions and Unsupervised Learning of Morphology. In Proceedings of COLING 2012, pages 3121–3136, Mumbai, India, December 2012. The COLING 2012 Organizing Committee.

Class Rectification Hard Mining for Imbalanced Deep Learning. In 2017 IEEE International Conference on Computer Vision (ICCV), pages 1869–1878, Venice, October 2017. IEEE.

Learning to Reweight Examples for Robust Deep Learning. In Proceedings of the 35th International Conference on Machine Learning, pages 4334–4343. PMLR, July 2018.

Cost-Sensitive Learning of Deep Feature Representations From Imbalanced Data. IEEE Transactions on Neural Networks and Learning Systems, 29(8):3573–3587, August 2018.

Bilevel Programming for Hyperparameter Optimization and Meta-Learning. In Proceedings of the 35th International Conference on Machine Learning, pages 1568–1577. PMLR, July 2018.

SGDR: Stochastic Gradient Descent with Warm Restarts. In International Conference on Learning Representations, November 2016.

Denote the logit vector as $z$ and the softmax output as $p$ with $p_j = \frac{\exp(z_j)}{\sum_k \exp(z_k)}$, and let $y$ be the one-hot label with target position $t$. For the cross-entropy loss $\mathcal{L} = -\sum_j y_j \log p_j$, we have
$$\frac{\partial \mathcal{L}}{\partial z_i} = -\sum_j y_j \frac{\partial \log p_j}{\partial z_i} = -\sum_j y_j \left(\delta_{ij} - p_i\right) \qquad (14)$$

$$= p_i \sum_j y_j - y_i = p_i - y_i. \qquad (15)$$

For the target (label) position $t$ we have $y_t = 1$ and

$$\frac{\partial \mathcal{L}}{\partial z_t} = p_t - 1. \qquad (16)$$

For any other position $i$ ($i \neq t$), we have $y_i = 0$ and

$$\frac{\partial \mathcal{L}}{\partial z_i} = p_i. \qquad (17)$$

Therefore, we can conclude that $\frac{\partial \mathcal{L}}{\partial z} = p - y$.
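The identity $\partial \mathcal{L} / \partial z = p - y$ can be checked numerically; below is a pure-Python finite-difference sketch with arbitrary logits and label:

```python
import math

def softmax(z):
    m = max(z)
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]

def ce_loss(z, t):
    # Cross-entropy loss with a one-hot target at position t.
    return -math.log(softmax(z)[t])

z, t = [1.0, -0.5, 2.0], 2
p = softmax(z)
analytic = [p[i] - (1.0 if i == t else 0.0) for i in range(3)]

# Central finite differences approximate the logit gradient.
eps = 1e-6
numeric = []
for i in range(3):
    zp = list(z); zp[i] += eps
    zm = list(z); zm[i] -= eps
    numeric.append((ce_loss(zp, t) - ce_loss(zm, t)) / (2 * eps))

assert all(abs(a - n) < 1e-5 for a, n in zip(analytic, numeric))
```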
With class-level weights $w_i$, the manipulated logit gradient is $\tilde{g}_i = w_i (p_i - y_i)$. Summing over all classes,

$$\sum_i \tilde{g}_i = w_t (p_t - 1) + \sum_{i \neq t} w_i p_i \qquad (18)$$

$$= -w_t \sum_{i \neq t} p_i + \sum_{i \neq t} w_i p_i. \qquad (19)$$

If $w_i = w_t$ for all $i \neq t$, the second term of (19) becomes $w_t \sum_{i \neq t} p_i$, therefore $\sum_i \tilde{g}_i$ can be rewritten as

$$\sum_i \tilde{g}_i = -w_t \sum_{i \neq t} p_i + w_t \sum_{i \neq t} p_i \qquad (20)$$

$$= 0. \qquad (21)$$

In general, we therefore impose the zero-mean constraint

$$\sum_i w_i (p_i - y_i) = 0. \qquad (22)$$
Without zero-mean constraint, the training becomes unstable. We plot the training curve of the CIFAR10 40% uniform noise setting in Figure 7.
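The instability is consistent with a basic property of the softmax: adding a constant to every logit leaves the probabilities unchanged, so the mean component of a manipulated logit gradient drifts the logits without affecting predictions. The sketch below illustrates this invariance and a simple mean-subtraction projection (an illustration of the idea, not necessarily the exact form used in GDW):

```python
import math

def softmax(z):
    m = max(z)
    e = [math.exp(v - m) for v in z]
    s = sum(e)
    return [v / s for v in e]

z = [1.0, 0.0, -2.0]
shifted = [v + 5.0 for v in z]
p, q = softmax(z), softmax(shifted)
# Softmax is invariant to a constant shift of the logits, so the mean
# component of a logit gradient cannot change the predictions.
assert all(abs(a - b) < 1e-12 for a, b in zip(p, q))

g = [0.5, -0.1, 0.2]                 # a manipulated logit gradient
mean = sum(g) / len(g)
g_centered = [v - mean for v in g]   # zero-mean projection
assert abs(sum(g_centered)) < 1e-12
```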
Following the training setting of shuMetaWeightNetLearningExplicit2019a , the classifier network is trained with SGD for all methods; the weight decay, initial learning rate, and mini-batch size follow that work. We use the cosine learning rate decay schedule loshchilovSGDRStochasticGradient2016 for a total of 80 epochs. The MLP weighting network is trained with Adam kingmaAdamMethodStochastic2015 with a fixed learning rate of 1e-3 and a weight decay of 1e-4. For GLC, we first train 40 epochs to estimate the label corruption matrix and then train another 40 epochs to evaluate its performance. Since Co-teaching uses two models, each model is trained for 40 epochs for a fair comparison. We repeat every experiment 5 times with different random seeds (seed = 1, 10, 100, 1000, 10000) for network initialization and label noise generation, and report the average test accuracy over the last 5 epochs as the model’s final performance. We use one V100 GPU for all the experiments.
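For reference, the cosine decay schedule loshchilovSGDRStochasticGradient2016 can be sketched in a few lines of Python (the 0.1 initial learning rate below is illustrative, not the paper's exact value):

```python
import math

def cosine_lr(step, total_steps, lr_max, lr_min=0.0):
    """Cosine learning-rate decay (SGDR-style schedule, no restarts)."""
    t = step / total_steps
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * t))

# The schedule starts at lr_max and anneals smoothly to lr_min over 80 epochs.
assert abs(cosine_lr(0, 80, 0.1) - 0.1) < 1e-12
assert abs(cosine_lr(80, 80, 0.1)) < 1e-12
assert cosine_lr(40, 80, 0.1) < cosine_lr(20, 80, 0.1)
```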
We conduct one more experiment under the imbalance setting to better verify the interpretability of GDW. As shown in Table 6, we report the ratio of class-level weights that increase after the gradient update on instances of each class in one epoch, where class 0 is the largest class and class 9 is the smallest.
Note that the weight on the target-class gradient flow carries the "is this class" information in the dataset. As a result, its ratio of increase should be large for small classes and small for large classes. As shown above, the ratio of increased target-class weights (the diagonal elements) grows from 0.036 to 0.935 as the class index increases from 0 to 9.
On the other hand, the weight on the gradient flow of a non-target class $j$ carries the "not class $j$" information in the dataset. If $j$ is a large class, this ratio should be large, and vice versa. For the larger classes (smaller indices), the ratios of increased weights on non-target instances are generally larger than 0.5, and for the smaller classes (larger indices), they are generally less than 0.5. These results align with our analysis of the interpretable information carried by gradient flows.
weight/class | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9
---|---|---|---|---|---|---|---|---|---|---
$w_0$ | 0.036 | 0.968 | 0.973 | 0.972 | 0.965 | 0.974 | 0.972 | 0.976 | 0.956 | 0.973
$w_1$ | 0.887 | 0.095 | 0.912 | 0.929 | 0.907 | 0.927 | 0.911 | 0.922 | 0.910 | 0.920
$w_2$ | 0.848 | 0.844 | 0.141 | 0.839 | 0.822 | 0.845 | 0.818 | 0.847 | 0.829 | 0.802
$w_3$ | 0.585 | 0.608 | 0.552 | 0.405 | 0.569 | 0.541 | 0.561 | 0.559 | 0.617 | 0.596
$w_4$ | 0.474 | 0.521 | 0.420 | 0.460 | 0.509 | 0.455 | 0.456 | 0.482 | 0.467 | 0.512
$w_5$ | 0.291 | 0.261 | 0.288 | 0.252 | 0.309 | 0.701 | 0.303 | 0.267 | 0.297 | 0.257
$w_6$ | 0.199 | 0.189 | 0.169 | 0.198 | 0.196 | 0.222 | 0.778 | 0.195 | 0.207 | 0.182
$w_7$ | 0.117 | 0.117 | 0.105 | 0.084 | 0.115 | 0.079 | 0.126 | 0.920 | 0.133 | 0.090
$w_8$ | 0.115 | 0.124 | 0.178 | 0.185 | 0.184 | 0.174 | 0.191 | 0.181 | 0.862 | 0.137
$w_9$ | 0.043 | 0.050 | 0.064 | 0.061 | 0.074 | 0.062 | 0.097 | 0.069 | 0.040 | 0.935
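The statistic reported in Table 6 can be computed with a sketch like the following (the before/after weight snapshots are hypothetical; the real experiment tracks each class-level weight across one epoch):

```python
def increase_ratio(before, after):
    """Fraction of class-level weights that increased after the gradient update."""
    ups = sum(1 for b, a in zip(before, after) if a > b)
    return ups / len(before)

# Hypothetical weights for one class over a handful of instances.
before = [0.50, 0.48, 0.52, 0.51]
after  = [0.53, 0.47, 0.55, 0.54]
assert increase_ratio(before, after) == 0.75  # 3 of 4 weights increased
```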
Method | L2RW | INSW | MW-Net | GDW |
---|---|---|---|---|
Accuracy (%) | 19.17 |
We have conducted experiments on the Places-LT dataset liuLargeScaleLongTailedRecognition2019 and compared GDW with other meta-learning-based methods. All methods share the same weight decay, batch size, number of training epochs, and initial learning rate with a cosine learning rate decay policy. The backbone network is ResNet18, and we use the ImageNet pre-trained model for initialization.
As shown in Table 7, GDW achieves the best performance among all the comparison methods and outperforms MWNet. This improvement is larger than that on CIFAR100. The reason is that GDW can manipulate class-level information and thus performs better on datasets with more classes (365 in Places-LT vs. 100 in CIFAR100). Besides, we observe that L2RW performs the worst, possibly because it suffers from unstable weights shuMetaWeightNetLearningExplicit2019a .
Similar to shuMetaWeightNetLearningExplicit2019a , we use the 14k validation set as the meta set and the original test set to evaluate the classifier’s final performance. For GLC, we first train half of the epochs to estimate the label corruption matrix and then train the other half to evaluate its performance. Since Co-teaching uses two models, each model is trained for half of the epochs for a fair comparison.
We conduct further experiments to verify the performance of GDW in the mixed setting, i.e., the coexistence of label noise and class imbalance. Specifically, we compare GDW with MW-Net shuMetaWeightNetLearningExplicit2019a , generally the second-best method, under the mixed setting of uniform noise and class imbalance on CIFAR10 and CIFAR100. As shown in Table 8, GDW demonstrates a clear performance gain over MW-Net, which means GDW can better tackle both problems simultaneously.
Dataset | Noise Ratio | Imb Factor | MW-Net | GDW |
---|---|---|---|---|
CIFAR10 | ||||
CIFAR10 | ||||
CIFAR10 | ||||
CIFAR10 | ||||
CIFAR100 | ||||
CIFAR100 | ||||
CIFAR100 | ||||
CIFAR100 |
For all authors…
Do the main claims made in the abstract and introduction accurately reflect the paper’s contributions and scope?
Did you describe the limitations of your work? The proposed method can only be applied to classification tasks.
Did you discuss any potential negative societal impacts of your work?
Have you read the ethics review guidelines and ensured that your paper conforms to them?
If you ran experiments…
Did you include the code, data, and instructions needed to reproduce the main experimental results (either in the supplemental material or as a URL)? We only use public datasets and the code is in the supplementary materials.
Did you specify all the training details (e.g., data splits, hyperparameters, how they were chosen)? Most of our settings follow shuMetaWeightNetLearningExplicit2019a and other details are in the appendix.
Did you report error bars (e.g., with respect to the random seed after running experiments multiple times)? We repeat all experiments on CIFAR10 and CIFAR100 with five different seeds and the mean and standard deviation are reported. For the Clothing1M dataset, we only run one experiment due to limited resources.
Did you include the total amount of compute and the type of resources used (e.g., type of GPUs, internal cluster, or cloud provider)? We use one V100 GPU. See the appendix for details.
If you are using existing assets (e.g., code, data, models) or curating/releasing new assets…
If your work uses existing assets, did you cite the creators? For datasets, we cite the papers of the CIFAR datasets and the Clothing1M dataset. For code, we cite heDeepResidualLearning2016 .
Did you mention the license of the assets?
Did you include any new assets either in the supplemental material or as a URL?
Did you discuss whether and how consent was obtained from people whose data you’re using/curating?
Did you discuss whether the data you are using/curating contains personally identifiable information or offensive content?
If you used crowdsourcing or conducted research with human subjects…
Did you include the full text of instructions given to participants and screenshots, if applicable?
Did you describe any potential participant risks, with links to Institutional Review Board (IRB) approvals, if applicable?
Did you include the estimated hourly wage paid to participants and the total amount spent on participant compensation?